Skip to content

Commit 7f27d15

Browse files
committed
Correction test_planar_pcs to be consistent with
the class Start of class for Planar HSA (to be continued)
1 parent 1bf71b5 commit 7f27d15

File tree

6 files changed

+1826
-267
lines changed

6 files changed

+1826
-267
lines changed
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from jax import jit, vmap
4+
5+
import jsrm
6+
from jsrm.systems.class_planar_hsa import PlanarHSA
7+
from jsrm.parameters.hsa_params import PARAMS_FPU_CONTROL, PARAMS_FPU_HYSTERESIS_CONTROL
8+
9+
from typing import Callable
10+
from jax import Array
11+
12+
import numpy as onp
13+
14+
from diffrax import Tsit5
15+
import cv2 # importing cv2
16+
17+
from pathlib import Path
18+
19+
jax.config.update("jax_enable_x64", True) # double precision
20+
jnp.set_printoptions(
21+
threshold=jnp.inf,
22+
linewidth=jnp.inf,
23+
formatter={"float_kind": lambda x: "0" if x == 0 else f"{x:.2e}"},
24+
)
25+
26+
27+
def draw_robot(
28+
robot: PlanarHSA,
29+
q: Array,
30+
width: int = 700,
31+
height: int = 700,
32+
num_points: int = 50,
33+
) -> onp.ndarray:
34+
"""
35+
Draw the robot in OpenCV.
36+
Args:
37+
robot:
38+
q: configuration as shape (3, )
39+
width: image width
40+
height: image height
41+
num_points: number of points to plot along the length of the robot
42+
"""
43+
# plotting in OpenCV
44+
h, w = height, width # img height and width
45+
ppm = h / (
46+
2.0 * jnp.sum(robot.params["lpc"] + robot.params["l"] + robot.params["ldc"])
47+
) # pixel per meter
48+
base_color = (0, 0, 0) # black base color in BGR
49+
backbone_color = (255, 0, 0) # blue robot color in BGR
50+
rod_color = (0, 255, 0) # green rod color in BGR
51+
platform_color = (0, 0, 255) # red platform color in BGR
52+
53+
batched_forward_kinematics_virtual_backbone_fn = vmap(
54+
robot.forward_kinematics_virtual_backbone_fn,
55+
in_axes=(None, 0), out_axes=-1
56+
)
57+
batched_forward_kinematics_rod_fn = vmap(
58+
robot.forward_kinematics_rod_fn,
59+
in_axes=(None, 0, None), out_axes=-1
60+
)
61+
batched_forward_kinematics_platform_fn = vmap(
62+
robot.forward_kinematics_platform_fn,
63+
in_axes=(None, 0), out_axes=0
64+
)
65+
66+
L_max = jnp.sum(robot.params["l"]) # total length of the robot
67+
# we use for plotting N points along the length of the robot
68+
s_ps = jnp.linspace(0, L_max, num_points)
69+
70+
# poses along the robot of shape (3, N)
71+
chiv_ps = batched_forward_kinematics_virtual_backbone_fn(q, s_ps) # poses of virtual backbone
72+
chiL_ps = batched_forward_kinematics_rod_fn(q, s_ps, 0) # poses of left rod
73+
chiR_ps = batched_forward_kinematics_rod_fn(q, s_ps, 1) # poses of left rod
74+
# poses of the platforms
75+
chip_ps = batched_forward_kinematics_platform_fn(q, jnp.arange(0, robot.num_segments))
76+
77+
img = 255 * onp.ones((w, h, 3), dtype=jnp.uint8) # initialize background to white
78+
uv_robot_origin = onp.array(
79+
[w // 2, h * (1 - 0.1)], dtype=jnp.int32
80+
) # in x-y pixel coordinates
81+
uv_robot_origin_jax = jnp.array(uv_robot_origin)
82+
83+
@jit
84+
def chi2u(chi: Array) -> Array:
85+
"""
86+
Map Cartesian coordinates to pixel coordinates.
87+
Args:
88+
chi: Cartesian poses of shape (3)
89+
90+
Returns:
91+
uv: pixel coordinates of shape (2)
92+
"""
93+
uv_off = jnp.array((chi[1:] * ppm), dtype=jnp.int32)
94+
# invert the v pixel coordinate
95+
uv_off = uv_off.at[1].set(-uv_off[1])
96+
# invert the v pixel coordinate
97+
uv = uv_robot_origin_jax + uv_off
98+
return uv
99+
100+
batched_chi2u = vmap(chi2u, in_axes=-1, out_axes=0)
101+
102+
# draw base
103+
cv2.rectangle(img, (0, uv_robot_origin[1]), (w, h), color=base_color, thickness=-1)
104+
105+
# draw the virtual backbone
106+
# add the first point of the proximal cap and the last point of the distal cap
107+
chiv_ps = jnp.concatenate(
108+
[
109+
(chiv_ps[:, 0] - jnp.array([0.0, 0.0, params["lpc"][0]])).reshape(3, 1),
110+
chiv_ps,
111+
(
112+
chiv_ps[:, -1]
113+
+ jnp.array(
114+
[
115+
chiv_ps[2, -1],
116+
-jnp.sin(chiv_ps[2, -1]) * params["ldc"][-1],
117+
jnp.cos(chiv_ps[2, -1]) * params["ldc"][-1],
118+
]
119+
)
120+
).reshape(3, 1),
121+
],
122+
axis=1,
123+
)
124+
curve_virtual_backbone = onp.array(batched_chi2u(chiv_ps))
125+
cv2.polylines(
126+
img, [curve_virtual_backbone], isClosed=False, color=backbone_color, thickness=5
127+
)
128+
129+
# draw the rods
130+
# add the first point of the proximal cap and the last point of the distal cap
131+
chiL_ps = jnp.concatenate(
132+
[
133+
(chiL_ps[:, 0] - jnp.array([0.0, 0.0, params["lpc"][0]])).reshape(3, 1),
134+
chiL_ps,
135+
(
136+
chiL_ps[:, -1]
137+
+ jnp.array(
138+
[
139+
chiL_ps[2, -1],
140+
-jnp.sin(chiL_ps[2, -1]) * params["ldc"][-1],
141+
jnp.cos(chiL_ps[2, -1]) * params["ldc"][-1],
142+
]
143+
)
144+
).reshape(3, 1),
145+
],
146+
axis=1,
147+
)
148+
curve_rod_left = onp.array(batched_chi2u(chiL_ps))
149+
cv2.polylines(
150+
img,
151+
[curve_rod_left],
152+
isClosed=False,
153+
color=rod_color,
154+
thickness=10,
155+
# thickness=2*int(ppm * params["rout"].mean(axis=0)[0])
156+
)
157+
# add the first point of the proximal cap and the last point of the distal cap
158+
chiR_ps = jnp.concatenate(
159+
[
160+
(chiR_ps[:, 0] - jnp.array([0.0, 0.0, params["lpc"][0]])).reshape(3, 1),
161+
chiR_ps,
162+
(
163+
chiR_ps[:, -1]
164+
+ jnp.array(
165+
[
166+
chiR_ps[2, -1],
167+
-jnp.sin(chiR_ps[2, -1]) * params["ldc"][-1],
168+
jnp.cos(chiR_ps[2, -1]) * params["ldc"][-1],
169+
]
170+
)
171+
).reshape(3, 1),
172+
],
173+
axis=1,
174+
)
175+
curve_rod_right = onp.array(batched_chi2u(chiR_ps))
176+
cv2.polylines(img, [curve_rod_right], isClosed=False, color=rod_color, thickness=10)
177+
178+
# draw the platform
179+
for i in range(chip_ps.shape[0]):
180+
# iterate over the platforms
181+
platform_R = jnp.array(
182+
[
183+
[jnp.cos(chip_ps[i, 0]), -jnp.sin(chip_ps[i, 0])],
184+
[jnp.sin(chip_ps[i, 0]), jnp.cos(chip_ps[i, 0])],
185+
]
186+
) # rotation matrix for the platform
187+
platform_llc = chip_ps[i, 1:] + platform_R @ jnp.array(
188+
[
189+
-params["pcudim"][i, 1] / 2, # go half the width to the left
190+
-params["pcudim"][i, 2] / 2, # go half the height down
191+
]
192+
) # lower left corner of the platform
193+
platform_ulc = chip_ps[i, 1:] + platform_R @ jnp.array(
194+
[
195+
-params["pcudim"][i, 1] / 2, # go half the width to the left
196+
+params["pcudim"][i, 2] / 2, # go half the height down
197+
]
198+
) # upper left corner of the platform
199+
platform_urc = chip_ps[i, 1:] + platform_R @ jnp.array(
200+
[
201+
+params["pcudim"][i, 1] / 2, # go half the width to the left
202+
+params["pcudim"][i, 2] / 2, # go half the height down
203+
]
204+
) # upper right corner of the platform
205+
platform_lrc = chip_ps[i, 1:] + platform_R @ jnp.array(
206+
[
207+
+params["pcudim"][i, 1] / 2, # go half the width to the left
208+
-params["pcudim"][i, 2] / 2, # go half the height down
209+
]
210+
) # lower right corner of the platform
211+
platform_curve = jnp.stack(
212+
[platform_llc, platform_ulc, platform_urc, platform_lrc, platform_llc],
213+
axis=1,
214+
)
215+
# cv2.polylines(img, [onp.array(batched_chi2u(platform_curve))], isClosed=True, color=platform_color, thickness=5)
216+
cv2.fillPoly(
217+
img, [onp.array(batched_chi2u(platform_curve))], color=platform_color
218+
)
219+
220+
return img
221+
222+
223+
if __name__ == "__main__":
224+
num_segments = 1
225+
num_rods_per_segment = 2
226+
227+
# filepath to symbolic expressions
228+
sym_exp_filepath = (
229+
Path(jsrm.__file__).parent
230+
/ "symbolic_expressions"
231+
/ f"planar_hsa_ns-{num_segments}_nrs-{num_rods_per_segment}.dill"
232+
)
233+
234+
# activate all strains (i.e. bending, shear, and axial)
235+
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
236+
consider_hysteresis = True
237+
238+
params = PARAMS_FPU_HYSTERESIS_CONTROL if consider_hysteresis else PARAMS_FPU_CONTROL
239+
# increase damping for simulation stability
240+
params["zetab"] = 5 * params["zetab"]
241+
params["zetash"] = 5 * params["zetash"]
242+
params["zetaa"] = 5 * params["zetaa"]
243+
244+
# ======================================================
245+
# Robot initialization
246+
# ======================================================
247+
robot = PlanarHSA(
248+
sym_exp_filepath=sym_exp_filepath,
249+
params=params,
250+
strain_selector=strain_selector,
251+
consider_hysteresis=consider_hysteresis,
252+
)
253+
254+
# =====================================================
255+
# Simulation upon time
256+
# =====================================================
257+
# Initial configuration
258+
q0 = jnp.array([jnp.pi, 0.0, 0.0])
259+
# Initial velocities
260+
qd0 = jnp.zeros_like(q0)
261+
# Motor actuation angles
262+
phi = jnp.array([jnp.pi, jnp.pi / 2])
263+
264+
# Displaying the image
265+
window_name = f"Planar HSA with {num_segments} segments"
266+
img = draw_robot(
267+
robot,
268+
q = q0,
269+
)
270+
271+
272+
# Simulation time parameters
273+
t0 = 0.0
274+
t1 = 5.0
275+
dt = 5e-5 # time step
276+
skip_step = 100 # how many time steps to skip in between video frames
277+
278+
# Solver
279+
solver = Tsit5()

examples/simulate_planar_pcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from jsrm.systems.planar_pcs import PlanarPCS
44
import jax.numpy as jnp
55

6-
from typing import Callable, Dict
6+
from typing import Callable
77
from jax import Array
88

99
import numpy as onp

0 commit comments

Comments
 (0)