Skip to content

Commit 1b3d3a9

Browse files
committed
Start working on tendon_acutated_planar_pcs
1 parent 1bde91a commit 1b3d3a9

File tree

6 files changed

+410
-6
lines changed

6 files changed

+410
-6
lines changed

examples/derive_planar_pcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import jsrm
44
from jsrm.symbolic_derivation.planar_pcs import symbolically_derive_planar_pcs_model
55

6-
NUM_SEGMENTS = 1
6+
NUM_SEGMENTS = 2
77

88
if __name__ == "__main__":
99
sym_exp_filepath = (
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import cv2 # importing cv2
2+
from functools import partial
3+
import jax
4+
5+
jax.config.update("jax_enable_x64", True) # double precision
6+
from diffrax import diffeqsolve, Euler, ODETerm, SaveAt, Tsit5
7+
from jax import Array, vmap
8+
from jax import numpy as jnp
9+
import matplotlib.pyplot as plt
10+
import numpy as onp
11+
from pathlib import Path
12+
from typing import Callable, Dict
13+
14+
import jsrm
15+
from jsrm import ode_factory
16+
from jsrm.systems import tendon_actuated_planar_pcs as planar_pcs
17+
18+
num_segments = 1
19+
20+
# filepath to symbolic expressions
21+
sym_exp_filepath = (
22+
Path(jsrm.__file__).parent
23+
/ "symbolic_expressions"
24+
/ f"planar_pcs_ns-{num_segments}.dill"
25+
)
26+
27+
# set parameters
28+
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
29+
params = {
30+
"th0": jnp.array(0.0), # initial orientation angle [rad]
31+
"l": 1e-1 * jnp.ones((num_segments,)),
32+
"r": 2e-2 * jnp.ones((num_segments,)),
33+
"rho": rho,
34+
"g": jnp.array([0.0, 9.81]),
35+
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
36+
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
37+
"d": 2e-2 * jnp.array([[1.0, -1.0]]).repeat(num_segments, axis=0), # distance of tendons from the central axis [m]
38+
}
39+
print("params d =\n", params["d"])
40+
params["D"] = 1e-3 * jnp.diag(
41+
(jnp.repeat(
42+
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
43+
) * params["l"][:, None]).flatten()
44+
)
45+
46+
# activate all strains (i.e. bending, shear, and axial)
47+
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
48+
49+
# define initial configuration
50+
q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0).flatten()
51+
# number of generalized coordinates
52+
n_q = q0.shape[0]
53+
54+
# set simulation parameters
55+
dt = 1e-4 # time step
56+
ts = jnp.arange(0.0, 2, dt) # time steps
57+
skip_step = 10 # how many time steps to skip in between video frames
58+
video_ts = ts[::skip_step] # time steps for video
59+
60+
# video settings
61+
video_width, video_height = 700, 700 # img height and width
62+
video_path = Path(__file__).parent / "videos" / f"planar_pcs_ns-{num_segments}.mp4"
63+
64+
65+
def draw_robot(
66+
batched_forward_kinematics_fn: Callable,
67+
params: Dict[str, Array],
68+
q: Array,
69+
width: int,
70+
height: int,
71+
num_points: int = 50,
72+
) -> onp.ndarray:
73+
# plotting in OpenCV
74+
h, w = height, width # img height and width
75+
ppm = h / (2.0 * jnp.sum(params["l"])) # pixel per meter
76+
base_color = (0, 0, 0) # black robot_color in BGR
77+
robot_color = (255, 0, 0) # black robot_color in BGR
78+
79+
# we use for plotting N points along the length of the robot
80+
s_ps = jnp.linspace(0, jnp.sum(params["l"]), num_points)
81+
82+
# poses along the robot of shape (3, N)
83+
chi_ps = batched_forward_kinematics_fn(params, q, s_ps)
84+
85+
img = 255 * onp.ones((w, h, 3), dtype=jnp.uint8) # initialize background to white
86+
curve_origin = onp.array(
87+
[w // 2, 0.1 * h], dtype=onp.int32
88+
) # in x-y pixel coordinates
89+
# draw base
90+
cv2.rectangle(img, (0, h - curve_origin[1]), (w, h), color=base_color, thickness=-1)
91+
# transform robot poses to pixel coordinates
92+
# should be of shape (N, 2)
93+
curve = onp.array((curve_origin + chi_ps[:2, :].T * ppm), dtype=onp.int32)
94+
# invert the v pixel coordinate
95+
curve[:, 1] = h - curve[:, 1]
96+
cv2.polylines(img, [curve], isClosed=False, color=robot_color, thickness=10)
97+
98+
return img
99+
100+
101+
if __name__ == "__main__":
102+
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
103+
planar_pcs.factory(sym_exp_filepath, strain_selector)
104+
)
105+
# jit the functions
106+
dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn))
107+
batched_forward_kinematics = vmap(
108+
forward_kinematics_fn, in_axes=(None, None, 0), out_axes=-1
109+
)
110+
111+
# import matplotlib.pyplot as plt
112+
# plt.plot(chi_ps[0, :], chi_ps[1, :])
113+
# plt.axis("equal")
114+
# plt.grid(True)
115+
# plt.xlabel("x [m]")
116+
# plt.ylabel("y [m]")
117+
# plt.show()
118+
119+
# Displaying the image
120+
# window_name = f"Planar PCS with {num_segments} segments"
121+
# img = draw_robot(batched_forward_kinematics, params, q0, video_width, video_height)
122+
# cv2.namedWindow(window_name)
123+
# cv2.imshow(window_name, img)
124+
# cv2.waitKey()
125+
# cv2.destroyWindow(window_name)
126+
127+
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
128+
tau = jnp.zeros_like(q0) # torques
129+
130+
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
131+
term = ODETerm(ode_fn)
132+
133+
sol = diffeqsolve(
134+
term,
135+
solver=Tsit5(),
136+
t0=ts[0],
137+
t1=ts[-1],
138+
dt0=dt,
139+
y0=x0,
140+
max_steps=None,
141+
saveat=SaveAt(ts=video_ts),
142+
)
143+
144+
print("sol.ys =\n", sol.ys)
145+
# the evolution of the generalized coordinates
146+
q_ts = sol.ys[:, :n_q]
147+
# the evolution of the generalized velocities
148+
q_d_ts = sol.ys[:, n_q:]
149+
150+
# evaluate the forward kinematics along the trajectory
151+
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
152+
params, q_ts, jnp.array([jnp.sum(params["l"])])
153+
)
154+
# plot the configuration vs time
155+
plt.figure()
156+
for segment_idx in range(num_segments):
157+
plt.plot(
158+
video_ts, q_ts[:, 3 * segment_idx + 0],
159+
label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]"
160+
)
161+
plt.plot(
162+
video_ts, q_ts[:, 3 * segment_idx + 1],
163+
label=r"$\sigma_\mathrm{sh," + str(segment_idx + 1) + "}$ [-]"
164+
)
165+
plt.plot(
166+
video_ts, q_ts[:, 3 * segment_idx + 2],
167+
label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]"
168+
)
169+
plt.xlabel("Time [s]")
170+
plt.ylabel("Configuration")
171+
plt.legend()
172+
plt.grid(True)
173+
plt.tight_layout()
174+
plt.show()
175+
# plot end-effector position vs time
176+
plt.figure()
177+
plt.plot(video_ts, chi_ee_ts[:, 0], label="x")
178+
plt.plot(video_ts, chi_ee_ts[:, 1], label="y")
179+
plt.xlabel("Time [s]")
180+
plt.ylabel("End-effector Position [m]")
181+
plt.legend()
182+
plt.grid(True)
183+
plt.box(True)
184+
plt.tight_layout()
185+
plt.show()
186+
# plot the end-effector position in the x-y plane as a scatter plot with the time as the color
187+
plt.figure()
188+
plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=video_ts, cmap="viridis")
189+
plt.axis("equal")
190+
plt.grid(True)
191+
plt.xlabel("End-effector x [m]")
192+
plt.ylabel("End-effector y [m]")
193+
plt.colorbar(label="Time [s]")
194+
plt.tight_layout()
195+
plt.show()
196+
# plt.figure()
197+
# plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1])
198+
# plt.axis("equal")
199+
# plt.grid(True)
200+
# plt.xlabel("End-effector x [m]")
201+
# plt.ylabel("End-effector y [m]")
202+
# plt.tight_layout()
203+
# plt.show()
204+
205+
# plot the energy along the trajectory
206+
kinetic_energy_fn_vmapped = vmap(
207+
partial(auxiliary_fns["kinetic_energy_fn"], params)
208+
)
209+
potential_energy_fn_vmapped = vmap(
210+
partial(auxiliary_fns["potential_energy_fn"], params)
211+
)
212+
U_ts = potential_energy_fn_vmapped(q_ts)
213+
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)
214+
plt.figure()
215+
plt.plot(video_ts, U_ts, label="Potential energy")
216+
plt.plot(video_ts, T_ts, label="Kinetic energy")
217+
plt.xlabel("Time [s]")
218+
plt.ylabel("Energy [J]")
219+
plt.legend()
220+
plt.grid(True)
221+
plt.box(True)
222+
plt.tight_layout()
223+
plt.show()
224+
225+
# create video
226+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
227+
video_path.parent.mkdir(parents=True, exist_ok=True)
228+
video = cv2.VideoWriter(
229+
str(video_path),
230+
fourcc,
231+
1 / (skip_step * dt), # fps
232+
(video_width, video_height),
233+
)
234+
235+
for time_idx, t in enumerate(video_ts):
236+
x = sol.ys[time_idx]
237+
img = draw_robot(
238+
batched_forward_kinematics,
239+
params,
240+
x[: (x.shape[0] // 2)],
241+
video_width,
242+
video_height,
243+
)
244+
video.write(img)
245+
246+
video.release()
247+
print(f"Video saved at {video_path}")

src/jsrm/symbolic_derivation/planar_pcs.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def symbolically_derive_planar_pcs_model(
4040
sp.symbols(f"r1:{num_segments + 1}", nonnegative=True)
4141
) # radius of each segment [m]
4242
g_syms = list(sp.symbols(f"g1:3")) # gravity vector
43+
d = sp.symbols("d", real=True, nonnegative=True) # # distance of the tendon from the neutral axis
4344

4445
# planar strains and their derivatives
4546
xi_syms = list(sp.symbols(f"xi1:{num_dof + 1}", nonzero=True)) # strains
@@ -59,6 +60,10 @@ def symbolically_derive_planar_pcs_model(
5960
chi_sms = []
6061
# Jacobians (positional + orientation) in each segment as a function of the point coordinate s and its time derivative
6162
J_sms, J_d_sms = [], []
63+
# tendon lengths for each segment as a function of the point coordinate s
64+
L_tend_sms = []
65+
# tendon length jacobians for each segment as a function of the point coordinate s
66+
J_tend_sms = []
6267
# cross-sectional area of each segment
6368
A = sp.zeros(num_segments)
6469
# second area moment of inertia of each segment
@@ -74,13 +79,14 @@ def symbolically_derive_planar_pcs_model(
7479
# initialize
7580
th_prev = th0
7681
p_prev = sp.Matrix([0, 0])
82+
L_tend = 0 # tendon length
7783
for i in range(num_segments):
7884
# bending strain
79-
kappa = xi[3 * i]
85+
kappa_be = xi[3 * i]
8086
# shear strain
81-
sigma_x = xi[3 * i + 1]
87+
sigma_sh = xi[3 * i + 1]
8288
# axial strain
83-
sigma_y = xi[3 * i + 2]
89+
sigma_ax = xi[3 * i + 2]
8490

8591
# compute the cross-sectional area of the rod
8692
A[i] = sp.pi * r[i] ** 2
@@ -89,13 +95,13 @@ def symbolically_derive_planar_pcs_model(
8995
I[i] = A[i] ** 2 / (4 * sp.pi)
9096

9197
# planar orientation of robot as a function of the point s
92-
th = th_prev + s * kappa
98+
th = th_prev + s * kappa_be
9399

94100
# absolute rotation of link
95101
R = sp.Matrix([[sp.cos(th), -sp.sin(th)], [sp.sin(th), sp.cos(th)]])
96102

97103
# derivative of Cartesian position as function of the point s
98-
dp_ds = R @ sp.Matrix([sigma_x, sigma_y])
104+
dp_ds = R @ sp.Matrix([sigma_sh, sigma_ax])
99105

100106
# position along the current rod as a function of the point s
101107
p = p_prev + sp.integrate(dp_ds, (s, 0.0, s))
@@ -146,12 +152,24 @@ def symbolically_derive_planar_pcs_model(
146152
# add potential energy of segment to previous segments
147153
U_g = U_g + U_gi
148154

155+
# simplify derived tendon length
156+
L_tend = L_tend + s * (1 + kappa_be * d) * sp.sqrt(sigma_sh**2 + sigma_ax**2)
157+
L_tend_sms.append(L_tend)
158+
print(f"L_tend of segment {i+1}:\n", L_tend)
159+
# take the derivative of the tendon length with respect to the configuration
160+
J_tend = sp.simplify(sp.Matrix([L_tend]).jacobian(xi))
161+
J_tend_sms.append(J_tend)
162+
print(f"J_tend of segment {i+1}:\n", J_tend)
163+
149164
# update the orientation for the next segment
150165
th_prev = th.subs(s, l[i])
151166

152167
# update the position for the next segment
153168
p_prev = p.subs(s, l[i])
154169

170+
# previous tendon length
171+
L_tend = L_tend.subs(s, l[i])
172+
155173
if simplify_expressions:
156174
# simplify mass matrix
157175
B = sp.simplify(B)
@@ -174,6 +192,7 @@ def symbolically_derive_planar_pcs_model(
174192
"r": r_syms,
175193
"rho": rho_syms,
176194
"g": g_syms,
195+
"d": d,
177196
},
178197
"state_syms": {
179198
"xi": xi_syms,
@@ -197,6 +216,8 @@ def symbolically_derive_planar_pcs_model(
197216
"C": C, # coriolis matrix
198217
"G": G, # gravity vector
199218
"U_g": U_g, # gravitational potential energy
219+
"L_tend_sms": L_tend_sms, # list of tendon lengths for each segment
220+
"J_tend_sms": J_tend_sms, # list of tendon length Jacobians for each segment
200221
},
201222
}
202223

-178 Bytes
Binary file not shown.
-5.21 KB
Binary file not shown.

0 commit comments

Comments
 (0)