Skip to content

Commit b4beba7

Browse files
authored
Merge branch 'main' into numerical-derivation
2 parents 877634a + 7b6590b commit b4beba7

File tree

8 files changed

+410
-10
lines changed

8 files changed

+410
-10
lines changed

examples/simulate_pneumatic_planar_pcs.py renamed to examples/simulate_pneumatically_actuated_planar_pcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import jsrm
1414
from jsrm import ode_factory
15-
from jsrm.systems import pneumatic_planar_pcs
15+
from jsrm.systems import pneumatically_actuated_planar_pcs
1616

1717
num_segments = 1
1818

@@ -48,7 +48,7 @@
4848
strain_selector = jnp.array([True, False, True])[None, :].repeat(num_segments, axis=0).flatten()
4949

5050
B_xi, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
51-
pneumatic_planar_pcs.factory(
51+
pneumatically_actuated_planar_pcs.factory(
5252
num_segments, sym_exp_filepath, strain_selector, # simplified_actuation_mapping=True
5353
)
5454
)
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
import cv2 # importing cv2
2+
from functools import partial
3+
import jax
4+
5+
jax.config.update("jax_enable_x64", True) # double precision
6+
from diffrax import diffeqsolve, Euler, ODETerm, SaveAt, Tsit5
7+
from jax import Array, vmap
8+
from jax import numpy as jnp
9+
import matplotlib.pyplot as plt
10+
import numpy as onp
11+
from pathlib import Path
12+
from typing import Callable, Dict
13+
14+
import jsrm
15+
from jsrm import ode_factory
16+
from jsrm.systems import tendon_actuated_planar_pcs as planar_pcs
17+
18+
num_segments = 1
19+
20+
# filepath to symbolic expressions
21+
sym_exp_filepath = (
22+
Path(jsrm.__file__).parent
23+
/ "symbolic_expressions"
24+
/ f"planar_pcs_ns-{num_segments}.dill"
25+
)
26+
27+
# set parameters
28+
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
29+
params = {
30+
"th0": jnp.array(0.0), # initial orientation angle [rad]
31+
"l": 1e-1 * jnp.ones((num_segments,)),
32+
"r": 2e-2 * jnp.ones((num_segments,)),
33+
"rho": rho,
34+
"g": jnp.array([0.0, 9.81]),
35+
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
36+
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
37+
"d": 2e-2 * jnp.array([[1.0, -1.0]]).repeat(num_segments, axis=0), # distance of tendons from the central axis [m]
38+
}
39+
params["D"] = 1e-3 * jnp.diag(
40+
(jnp.repeat(
41+
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
42+
) * params["l"][:, None]).flatten()
43+
)
44+
45+
# activate all strains (i.e. bending, shear, and axial)
46+
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
47+
# actuation selector for the segments
48+
segment_actuation_selector = jnp.ones((num_segments,), dtype=bool)
49+
# segment_actuation_selector = jnp.array([False, True]) # only the last segment is actuated
50+
51+
# define initial configuration
52+
q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0).flatten()
53+
# number of generalized coordinates
54+
n_q = q0.shape[0]
55+
56+
# set simulation parameters
57+
dt = 1e-4 # time step
58+
ts = jnp.arange(0.0, 10.0, dt) # time steps
59+
skip_step = 10 # how many time steps to skip in between video frames
60+
video_ts = ts[::skip_step] # time steps for video
61+
62+
# video settings
63+
video_width, video_height = 700, 700 # img height and width
64+
video_path = Path(__file__).parent / "videos" / f"planar_pcs_ns-{num_segments}.mp4"
65+
66+
67+
def draw_robot(
68+
batched_forward_kinematics_fn: Callable,
69+
params: Dict[str, Array],
70+
q: Array,
71+
width: int,
72+
height: int,
73+
num_points: int = 50,
74+
) -> onp.ndarray:
75+
# plotting in OpenCV
76+
h, w = height, width # img height and width
77+
ppm = h / (2.0 * jnp.sum(params["l"])) # pixel per meter
78+
base_color = (0, 0, 0) # black robot_color in BGR
79+
robot_color = (255, 0, 0) # black robot_color in BGR
80+
81+
# we use for plotting N points along the length of the robot
82+
s_ps = jnp.linspace(0, jnp.sum(params["l"]), num_points)
83+
84+
# poses along the robot of shape (3, N)
85+
chi_ps = batched_forward_kinematics_fn(params, q, s_ps)
86+
87+
img = 255 * onp.ones((w, h, 3), dtype=jnp.uint8) # initialize background to white
88+
curve_origin = onp.array(
89+
[w // 2, 0.1 * h], dtype=onp.int32
90+
) # in x-y pixel coordinates
91+
# draw base
92+
cv2.rectangle(img, (0, h - curve_origin[1]), (w, h), color=base_color, thickness=-1)
93+
# transform robot poses to pixel coordinates
94+
# should be of shape (N, 2)
95+
curve = onp.array((curve_origin + chi_ps[:2, :].T * ppm), dtype=onp.int32)
96+
# invert the v pixel coordinate
97+
curve[:, 1] = h - curve[:, 1]
98+
cv2.polylines(img, [curve], isClosed=False, color=robot_color, thickness=10)
99+
100+
return img
101+
102+
103+
if __name__ == "__main__":
104+
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
105+
planar_pcs.factory(num_segments, sym_exp_filepath, strain_selector, segment_actuation_selector=segment_actuation_selector)
106+
)
107+
actuation_mapping_fn = auxiliary_fns["actuation_mapping_fn"]
108+
# jit the functions
109+
dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn))
110+
batched_forward_kinematics = vmap(
111+
forward_kinematics_fn, in_axes=(None, None, 0), out_axes=-1
112+
)
113+
114+
# test the actuation mapping function
115+
xi_eq = jnp.array([0.0, 0.0, 1.0])[None].repeat(num_segments, axis=0).flatten()
116+
B_xi = strain_basis
117+
# call the actuation mapping function
118+
A = actuation_mapping_fn(
119+
forward_kinematics_fn,
120+
auxiliary_fns["jacobian_fn"],
121+
params,
122+
B_xi,
123+
xi_eq,
124+
jnp.zeros_like(q0),
125+
)
126+
print("A =\n", A)
127+
128+
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
129+
u = jnp.array([1.0, 1.0])[None].repeat(num_segments, axis=0).flatten() # tendon tensions
130+
# u = 2e-1 * jnp.array([2.0, 0.0, 0.0, 1.0])
131+
print("u =\n", u)
132+
133+
ode_fn = ode_factory(dynamical_matrices_fn, params, u)
134+
term = ODETerm(ode_fn)
135+
136+
sol = diffeqsolve(
137+
term,
138+
solver=Tsit5(),
139+
t0=ts[0],
140+
t1=ts[-1],
141+
dt0=dt,
142+
y0=x0,
143+
max_steps=None,
144+
saveat=SaveAt(ts=video_ts),
145+
)
146+
147+
print("sol.ys =\n", sol.ys)
148+
# the evolution of the generalized coordinates
149+
q_ts = sol.ys[:, :n_q]
150+
# the evolution of the generalized velocities
151+
q_d_ts = sol.ys[:, n_q:]
152+
153+
# evaluate the forward kinematics along the trajectory
154+
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
155+
params, q_ts, jnp.array([jnp.sum(params["l"])])
156+
)
157+
# plot the configuration vs time
158+
plt.figure()
159+
for segment_idx in range(num_segments):
160+
plt.plot(
161+
video_ts, q_ts[:, 3 * segment_idx + 0],
162+
label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]"
163+
)
164+
plt.plot(
165+
video_ts, q_ts[:, 3 * segment_idx + 1],
166+
label=r"$\sigma_\mathrm{sh," + str(segment_idx + 1) + "}$ [-]"
167+
)
168+
plt.plot(
169+
video_ts, q_ts[:, 3 * segment_idx + 2],
170+
label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]"
171+
)
172+
plt.xlabel("Time [s]")
173+
plt.ylabel("Configuration")
174+
plt.legend()
175+
plt.grid(True)
176+
plt.tight_layout()
177+
plt.show()
178+
# plot end-effector position vs time
179+
plt.figure()
180+
plt.plot(video_ts, chi_ee_ts[:, 0], label="x")
181+
plt.plot(video_ts, chi_ee_ts[:, 1], label="y")
182+
plt.xlabel("Time [s]")
183+
plt.ylabel("End-effector Position [m]")
184+
plt.legend()
185+
plt.grid(True)
186+
plt.box(True)
187+
plt.tight_layout()
188+
plt.show()
189+
# plot the end-effector position in the x-y plane as a scatter plot with the time as the color
190+
plt.figure()
191+
plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=video_ts, cmap="viridis")
192+
plt.axis("equal")
193+
plt.grid(True)
194+
plt.xlabel("End-effector x [m]")
195+
plt.ylabel("End-effector y [m]")
196+
plt.colorbar(label="Time [s]")
197+
plt.tight_layout()
198+
plt.show()
199+
# plt.figure()
200+
# plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1])
201+
# plt.axis("equal")
202+
# plt.grid(True)
203+
# plt.xlabel("End-effector x [m]")
204+
# plt.ylabel("End-effector y [m]")
205+
# plt.tight_layout()
206+
# plt.show()
207+
208+
# plot the energy along the trajectory
209+
kinetic_energy_fn_vmapped = vmap(
210+
partial(auxiliary_fns["kinetic_energy_fn"], params)
211+
)
212+
potential_energy_fn_vmapped = vmap(
213+
partial(auxiliary_fns["potential_energy_fn"], params)
214+
)
215+
U_ts = potential_energy_fn_vmapped(q_ts)
216+
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)
217+
plt.figure()
218+
plt.plot(video_ts, U_ts, label="Potential energy")
219+
plt.plot(video_ts, T_ts, label="Kinetic energy")
220+
plt.xlabel("Time [s]")
221+
plt.ylabel("Energy [J]")
222+
plt.legend()
223+
plt.grid(True)
224+
plt.box(True)
225+
plt.tight_layout()
226+
plt.show()
227+
228+
# create video
229+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
230+
video_path.parent.mkdir(parents=True, exist_ok=True)
231+
video = cv2.VideoWriter(
232+
str(video_path),
233+
fourcc,
234+
1 / (skip_step * dt), # fps
235+
(video_width, video_height),
236+
)
237+
238+
for time_idx, t in enumerate(video_ts):
239+
x = sol.ys[time_idx]
240+
img = draw_robot(
241+
batched_forward_kinematics,
242+
params,
243+
x[: (x.shape[0] // 2)],
244+
video_width,
245+
video_height,
246+
)
247+
video.write(img)
248+
249+
video.release()
250+
print(f"Video saved at {video_path}")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ name = "jsrm" # Required
1717
#
1818
# For a discussion on single-sourcing the version, see
1919
# https://packaging.python.org/guides/single-sourcing-package-version/
20-
version = "0.0.15" # Required
20+
version = "0.0.17" # Required
2121

2222
# This is a one-line description or tagline of what your project does. This
2323
# corresponds to the "Summary" metadata field:

src/jsrm/symbolic_derivation/planar_pcs.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def symbolically_derive_planar_pcs_model(
4040
sp.symbols(f"r1:{num_segments + 1}", nonnegative=True)
4141
) # radius of each segment [m]
4242
g_syms = list(sp.symbols(f"g1:3")) # gravity vector
43+
d = sp.symbols("d", real=True, nonnegative=True) # # distance of the tendon from the neutral axis
4344

4445
# planar strains and their derivatives
4546
xi_syms = list(sp.symbols(f"xi1:{num_dof + 1}", nonzero=True)) # strains
@@ -59,6 +60,10 @@ def symbolically_derive_planar_pcs_model(
5960
chi_sms = []
6061
# Jacobians (positional + orientation) in each segment as a function of the point coordinate s and its time derivative
6162
J_sms, J_d_sms = [], []
63+
# tendon lengths for each segment as a function of the point coordinate s
64+
L_tend_sms = []
65+
# tendon length jacobians for each segment as a function of the point coordinate s
66+
J_tend_sms = []
6267
# cross-sectional area of each segment
6368
A = sp.zeros(num_segments, 1)
6469
# second area moment of inertia of each segment
@@ -74,13 +79,14 @@ def symbolically_derive_planar_pcs_model(
7479
# initialize
7580
th_prev = th0
7681
p_prev = sp.Matrix([0, 0])
82+
L_tend = 0 # tendon length
7783
for i in range(num_segments):
7884
# bending strain
79-
kappa = xi[3 * i]
85+
kappa_be = xi[3 * i]
8086
# shear strain
81-
sigma_x = xi[3 * i + 1]
87+
sigma_sh = xi[3 * i + 1]
8288
# axial strain
83-
sigma_y = xi[3 * i + 2]
89+
sigma_ax = xi[3 * i + 2]
8490

8591
# compute the cross-sectional area of the rod
8692
A[i] = sp.pi * r[i] ** 2
@@ -89,13 +95,13 @@ def symbolically_derive_planar_pcs_model(
8995
I[i] = A[i] ** 2 / (4 * sp.pi)
9096

9197
# planar orientation of robot as a function of the point s
92-
th = th_prev + s * kappa
98+
th = th_prev + s * kappa_be
9399

94100
# absolute rotation of link
95101
R = sp.Matrix([[sp.cos(th), -sp.sin(th)], [sp.sin(th), sp.cos(th)]])
96102

97103
# derivative of Cartesian position as function of the point s
98-
dp_ds = R @ sp.Matrix([sigma_x, sigma_y])
104+
dp_ds = R @ sp.Matrix([sigma_sh, sigma_ax])
99105

100106
# position along the current rod as a function of the point s
101107
p = p_prev + sp.integrate(dp_ds, (s, 0.0, s))
@@ -146,12 +152,24 @@ def symbolically_derive_planar_pcs_model(
146152
# add potential energy of segment to previous segments
147153
U_g = U_g + U_gi
148154

155+
# simplify derived tendon length
156+
L_tend = L_tend + s * sp.sqrt(sigma_sh**2 + (sigma_ax + kappa_be * d)**2)
157+
L_tend_sms.append(L_tend)
158+
print(f"L_tend of segment {i+1}:\n", L_tend)
159+
# take the derivative of the tendon length with respect to the configuration
160+
J_tend = sp.simplify(sp.Matrix([L_tend]).jacobian(xi))
161+
J_tend_sms.append(J_tend)
162+
print(f"J_tend of segment {i+1}:\n", J_tend)
163+
149164
# update the orientation for the next segment
150165
th_prev = th.subs(s, l[i])
151166

152167
# update the position for the next segment
153168
p_prev = p.subs(s, l[i])
154169

170+
# previous tendon length
171+
L_tend = L_tend.subs(s, l[i])
172+
155173
if simplify_expressions:
156174
# simplify mass matrix
157175
B = sp.simplify(B)
@@ -197,6 +215,8 @@ def symbolically_derive_planar_pcs_model(
197215
"C": C, # coriolis matrix
198216
"G": G, # gravity vector
199217
"U_g": U_g, # gravitational potential energy
218+
"L_tend_sms": L_tend_sms, # list of tendon lengths for each segment
219+
"J_tend_sms": J_tend_sms, # list of tendon length Jacobians for each segment
200220
},
201221
}
202222

438 KB
Binary file not shown.

src/jsrm/systems/planar_pcs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def actuation_mapping_fn(
242242
jacobian_fn: Callable,
243243
params: Dict[str, Array],
244244
B_xi: Array,
245+
xi_eq: Array,
245246
q: Array,
246247
) -> Array:
247248
"""
@@ -252,6 +253,7 @@ def actuation_mapping_fn(
252253
jacobian_fn: function to compute the Jacobian
253254
params: dictionary with robot parameters
254255
B_xi: strain basis matrix
256+
xi_eq: equilibrium strains as array of shape (n_xi,)
255257
q: configuration of the robot
256258
Returns:
257259
A: actuation matrix of shape (n_xi, n_xi) where n_xi is the number of strains.
@@ -365,7 +367,7 @@ def dynamical_matrices_fn(
365367
# compute the stiffness matrix
366368
K = stiffness_fn(params, B_xi, formulate_in_strain_space=True)
367369
# compute the actuation matrix
368-
A = actuation_mapping_fn(forward_kinematics_fn, jacobian_fn, params, B_xi, q)
370+
A = actuation_mapping_fn(forward_kinematics_fn, jacobian_fn, params, B_xi, xi_eq, q)
369371

370372
# dissipative matrix from the parameters
371373
D = params.get("D", jnp.zeros((n_xi, n_xi)))

0 commit comments

Comments
 (0)