Skip to content

Commit ab00835

Browse files
committed
Add plot backup, full-JAX PCS planar + Gaussian quadrature integration scheme
1 parent c3e87a1 commit ab00835

File tree

5 files changed

+1013
-40
lines changed

5 files changed

+1013
-40
lines changed

examples/figures/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.png

examples/simulate_planar_pcs.py

Lines changed: 103 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,21 @@
1313

1414
import jsrm
1515
from jsrm import ode_factory
16-
from jsrm.systems import planar_pcs
16+
from jsrm.systems import planar_pcs, planar_pcs_num
1717

18-
num_segments = 1
18+
import time
1919

20-
# filepath to symbolic expressions
21-
sym_exp_filepath = (
22-
Path(jsrm.__file__).parent
23-
/ "symbolic_expressions"
24-
/ f"planar_pcs_ns-{num_segments}.dill"
25-
)
20+
num_segments = 2
21+
22+
type_of_derivation = "numeric" #"symbolic" #
23+
24+
if type_of_derivation == "symbolic":
25+
# filepath to symbolic expressions
26+
sym_exp_filepath = (
27+
Path(jsrm.__file__).parent
28+
/ "symbolic_expressions"
29+
/ f"planar_pcs_ns-{num_segments}.dill"
30+
)
2631

2732
# set parameters
2833
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
@@ -49,6 +54,9 @@
4954
# number of generalized coordinates
5055
n_q = q0.shape[0]
5156

57+
# ===================================================
58+
# For video generation
59+
# ======================
5260
# set simulation parameters
5361
dt = 1e-4 # time step
5462
ts = jnp.arange(0.0, 2, dt) # time steps
@@ -57,8 +65,7 @@
5765

5866
# video settings
5967
video_width, video_height = 700, 700 # img height and width
60-
video_path = Path(__file__).parent / "videos" / f"planar_pcs_ns-{num_segments}.mp4"
61-
68+
video_path = Path(__file__).parent / "videos" / f"planar_pcs_ns-{num_segments}-{('symb' if type_of_derivation == 'symbolic' else 'num')}.mp4"
6269

6370
def draw_robot(
6471
batched_forward_kinematics_fn: Callable,
@@ -95,39 +102,62 @@ def draw_robot(
95102

96103
return img
97104

105+
# ===================================================
106+
# For figure saving
107+
# ======================
108+
figures_path_parent = Path(__file__).parent / "figures"
109+
extension = f"planar_pcs_ns-{num_segments}-{('symb' if type_of_derivation == 'symbolic' else 'num')}"
110+
figures_path_parent.mkdir(parents=True, exist_ok=True)
98111

99112
if __name__ == "__main__":
100-
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
101-
planar_pcs.factory(sym_exp_filepath, strain_selector)
102-
)
113+
print("Type of derivation:", type_of_derivation)
114+
print("Number of segments:", num_segments, "\n")
115+
116+
print("Importing the planar PCS model...")
117+
timer_start = time.time()
118+
# import jsrm
119+
if type_of_derivation == "symbolic":
120+
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
121+
planar_pcs.factory(sym_exp_filepath, strain_selector)
122+
)
123+
124+
elif type_of_derivation == "numeric":
125+
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
126+
planar_pcs_num.factory(num_segments, strain_selector)
127+
)
128+
else:
129+
raise ValueError("type_of_derivation must be 'symbolic' or 'numeric'")
130+
103131
# jit the functions
104132
dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn))
105133
batched_forward_kinematics = vmap(
106134
forward_kinematics_fn, in_axes=(None, None, 0), out_axes=-1
107135
)
136+
timer_end = time.time()
137+
print(f"Importing the planar PCS model took {timer_end - timer_start:.2f} seconds. \n")
108138

109-
# import matplotlib.pyplot as plt
110-
# plt.plot(chi_ps[0, :], chi_ps[1, :])
111-
# plt.axis("equal")
112-
# plt.grid(True)
113-
# plt.xlabel("x [m]")
114-
# plt.ylabel("y [m]")
115-
# plt.show()
116-
117-
# Displaying the image
118-
# window_name = f"Planar PCS with {num_segments} segments"
119-
# img = draw_robot(batched_forward_kinematics, params, q0, video_width, video_height)
120-
# cv2.namedWindow(window_name)
121-
# cv2.imshow(window_name, img)
122-
# cv2.waitKey()
123-
# cv2.destroyWindow(window_name)
139+
print("Evaluating the dynamical matrices...")
140+
timer_start = time.time()
141+
B, C, G, K, D, A = dynamical_matrices_fn(params, q0, jnp.zeros_like(q0))
142+
print("B =\n", B)
143+
print("C =\n", C)
144+
print("G =\n", G)
145+
print("K =\n", K)
146+
print("D =\n", D)
147+
print("A =\n", A)
148+
timer_end = time.time()
149+
print(f"Evaluating the dynamical matrices took {timer_end - timer_start:.2f} seconds. \n")
124150

125151
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
126152
tau = jnp.zeros_like(q0) # torques
127153

154+
timer_start = time.time()
128155
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
129156
term = ODETerm(ode_fn)
157+
timer_end = time.time()
130158

159+
print("Solving the ODE...")
160+
timer_start = time.time()
131161
sol = diffeqsolve(
132162
term,
133163
solver=Tsit5(),
@@ -144,13 +174,22 @@ def draw_robot(
144174
q_ts = sol.ys[:, :n_q]
145175
# the evolution of the generalized velocities
146176
q_d_ts = sol.ys[:, n_q:]
177+
178+
timer_end = time.time()
179+
print(f"Solving the ODE took {timer_end - timer_start:.2f} seconds. \n")
147180

181+
print("Evaluating the forward kinematics...")
182+
timer_start = time.time()
148183
# evaluate the forward kinematics along the trajectory
149184
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
150185
params, q_ts, jnp.array([jnp.sum(params["l"])])
151186
)
187+
timer_end = time.time()
188+
print(f"Evaluating the forward kinematics took {timer_end - timer_start:.2f} seconds. ")
189+
152190
# plot the configuration vs time
153191
plt.figure()
192+
plt.title("Configuration vs Time")
154193
for segment_idx in range(num_segments):
155194
plt.plot(
156195
video_ts, q_ts[:, 3 * segment_idx + 0],
@@ -169,9 +208,15 @@ def draw_robot(
169208
plt.legend()
170209
plt.grid(True)
171210
plt.tight_layout()
211+
plt.savefig(
212+
figures_path_parent / f"config_vs_time_{extension}.png", bbox_inches="tight", dpi=300
213+
)
214+
print("Figures saved at", figures_path_parent / f"config_vs_time_{extension}.png")
172215
plt.show()
216+
173217
# plot end-effector position vs time
174218
plt.figure()
219+
plt.title("End-effector position vs Time")
175220
plt.plot(video_ts, chi_ee_ts[:, 0], label="x")
176221
plt.plot(video_ts, chi_ee_ts[:, 1], label="y")
177222
plt.xlabel("Time [s]")
@@ -180,26 +225,30 @@ def draw_robot(
180225
plt.grid(True)
181226
plt.box(True)
182227
plt.tight_layout()
228+
plt.savefig(
229+
figures_path_parent / f"end_effector_position_vs_time_{extension}.png", bbox_inches="tight", dpi=300
230+
)
231+
print("Figures saved at", figures_path_parent / f"end_effector_position_vs_time_{extension}.png ")
183232
plt.show()
233+
184234
# plot the end-effector position in the x-y plane as a scatter plot with the time as the color
185235
plt.figure()
236+
plt.title("End-effector position in the x-y plane")
186237
plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=video_ts, cmap="viridis")
187238
plt.axis("equal")
188239
plt.grid(True)
189240
plt.xlabel("End-effector x [m]")
190241
plt.ylabel("End-effector y [m]")
191242
plt.colorbar(label="Time [s]")
192243
plt.tight_layout()
244+
plt.savefig(
245+
figures_path_parent / f"end_effector_position_xy_{extension}.png", bbox_inches="tight", dpi=300
246+
)
247+
print("Figures saved at", figures_path_parent / f"end_effector_position_xy_{extension}.png \n")
193248
plt.show()
194-
# plt.figure()
195-
# plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1])
196-
# plt.axis("equal")
197-
# plt.grid(True)
198-
# plt.xlabel("End-effector x [m]")
199-
# plt.ylabel("End-effector y [m]")
200-
# plt.tight_layout()
201-
# plt.show()
202-
249+
250+
print("Evaluating the energy...")
251+
timer_start = time.time()
203252
# plot the energy along the trajectory
204253
kinetic_energy_fn_vmapped = vmap(
205254
partial(auxiliary_fns["kinetic_energy_fn"], params)
@@ -209,7 +258,13 @@ def draw_robot(
209258
)
210259
U_ts = potential_energy_fn_vmapped(q_ts)
211260
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)
261+
timer_end = time.time()
262+
print(f"Evaluating the energy took {timer_end - timer_start:.2f} seconds.")
263+
264+
# plot the energy vs time
212265
plt.figure()
266+
plt.title("Energy vs Time")
267+
plt.plot(video_ts, U_ts + T_ts, label="Total energy")
213268
plt.plot(video_ts, U_ts, label="Potential energy")
214269
plt.plot(video_ts, T_ts, label="Kinetic energy")
215270
plt.xlabel("Time [s]")
@@ -218,8 +273,14 @@ def draw_robot(
218273
plt.grid(True)
219274
plt.box(True)
220275
plt.tight_layout()
276+
plt.savefig(
277+
figures_path_parent / f"energy_vs_time_{extension}.png", bbox_inches="tight", dpi=300
278+
)
279+
print("Figures saved at", figures_path_parent / f"energy_vs_time_{extension}.png \n")
221280
plt.show()
222281

282+
print("Drawing the robot...")
283+
timer_start = time.time()
223284
# create video
224285
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
225286
video_path.parent.mkdir(parents=True, exist_ok=True)
@@ -242,4 +303,8 @@ def draw_robot(
242303
video.write(img)
243304

244305
video.release()
245-
print(f"Video saved at {video_path}")
306+
timer_end = time.time()
307+
print(f"Drawing the robot took {timer_end - timer_start:.2f} seconds.")
308+
print(f"Video saved at {video_path}. \n")
309+
310+

src/jsrm/symbolic_derivation/planar_pcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def symbolically_derive_planar_pcs_model(
6060
# Jacobians (positional + orientation) in each segment as a function of the point coordinate s and its time derivative
6161
J_sms, J_d_sms = [], []
6262
# cross-sectional area of each segment
63-
A = sp.zeros(num_segments)
63+
A = sp.zeros(num_segments, 1)
6464
# second area moment of inertia of each segment
65-
I = sp.zeros(num_segments)
65+
I = sp.zeros(num_segments, 1)
6666
# inertia matrix
6767
B = sp.zeros(num_dof, num_dof)
6868
# potential energy

0 commit comments

Comments
 (0)