1
+ import jax
2
+ import jax .numpy as jnp
3
+ from jax import jit , vmap
4
+
5
+ import jsrm
6
+ from jsrm .systems .class_planar_hsa import PlanarHSA
7
+ from jsrm .parameters .hsa_params import PARAMS_FPU_CONTROL , PARAMS_FPU_HYSTERESIS_CONTROL
8
+
9
+ from typing import Callable
10
+ from jax import Array
11
+
12
+ import numpy as onp
13
+
14
+ from diffrax import Tsit5
15
+ import cv2 # importing cv2
16
+
17
+ from pathlib import Path
18
+
19
+ jax .config .update ("jax_enable_x64" , True ) # double precision
20
+ jnp .set_printoptions (
21
+ threshold = jnp .inf ,
22
+ linewidth = jnp .inf ,
23
+ formatter = {"float_kind" : lambda x : "0" if x == 0 else f"{ x :.2e} " },
24
+ )
25
+
26
+
27
+ def draw_robot (
28
+ robot : PlanarHSA ,
29
+ q : Array ,
30
+ width : int = 700 ,
31
+ height : int = 700 ,
32
+ num_points : int = 50 ,
33
+ ) -> onp .ndarray :
34
+ """
35
+ Draw the robot in OpenCV.
36
+ Args:
37
+ robot:
38
+ q: configuration as shape (3, )
39
+ width: image width
40
+ height: image height
41
+ num_points: number of points to plot along the length of the robot
42
+ """
43
+ # plotting in OpenCV
44
+ h , w = height , width # img height and width
45
+ ppm = h / (
46
+ 2.0 * jnp .sum (robot .params ["lpc" ] + robot .params ["l" ] + robot .params ["ldc" ])
47
+ ) # pixel per meter
48
+ base_color = (0 , 0 , 0 ) # black base color in BGR
49
+ backbone_color = (255 , 0 , 0 ) # blue robot color in BGR
50
+ rod_color = (0 , 255 , 0 ) # green rod color in BGR
51
+ platform_color = (0 , 0 , 255 ) # red platform color in BGR
52
+
53
+ batched_forward_kinematics_virtual_backbone_fn = vmap (
54
+ robot .forward_kinematics_virtual_backbone_fn ,
55
+ in_axes = (None , 0 ), out_axes = - 1
56
+ )
57
+ batched_forward_kinematics_rod_fn = vmap (
58
+ robot .forward_kinematics_rod_fn ,
59
+ in_axes = (None , 0 , None ), out_axes = - 1
60
+ )
61
+ batched_forward_kinematics_platform_fn = vmap (
62
+ robot .forward_kinematics_platform_fn ,
63
+ in_axes = (None , 0 ), out_axes = 0
64
+ )
65
+
66
+ L_max = jnp .sum (robot .params ["l" ]) # total length of the robot
67
+ # we use for plotting N points along the length of the robot
68
+ s_ps = jnp .linspace (0 , L_max , num_points )
69
+
70
+ # poses along the robot of shape (3, N)
71
+ chiv_ps = batched_forward_kinematics_virtual_backbone_fn (q , s_ps ) # poses of virtual backbone
72
+ chiL_ps = batched_forward_kinematics_rod_fn (q , s_ps , 0 ) # poses of left rod
73
+ chiR_ps = batched_forward_kinematics_rod_fn (q , s_ps , 1 ) # poses of left rod
74
+ # poses of the platforms
75
+ chip_ps = batched_forward_kinematics_platform_fn (q , jnp .arange (0 , robot .num_segments ))
76
+
77
+ img = 255 * onp .ones ((w , h , 3 ), dtype = jnp .uint8 ) # initialize background to white
78
+ uv_robot_origin = onp .array (
79
+ [w // 2 , h * (1 - 0.1 )], dtype = jnp .int32
80
+ ) # in x-y pixel coordinates
81
+ uv_robot_origin_jax = jnp .array (uv_robot_origin )
82
+
83
+ @jit
84
+ def chi2u (chi : Array ) -> Array :
85
+ """
86
+ Map Cartesian coordinates to pixel coordinates.
87
+ Args:
88
+ chi: Cartesian poses of shape (3)
89
+
90
+ Returns:
91
+ uv: pixel coordinates of shape (2)
92
+ """
93
+ uv_off = jnp .array ((chi [1 :] * ppm ), dtype = jnp .int32 )
94
+ # invert the v pixel coordinate
95
+ uv_off = uv_off .at [1 ].set (- uv_off [1 ])
96
+ # invert the v pixel coordinate
97
+ uv = uv_robot_origin_jax + uv_off
98
+ return uv
99
+
100
+ batched_chi2u = vmap (chi2u , in_axes = - 1 , out_axes = 0 )
101
+
102
+ # draw base
103
+ cv2 .rectangle (img , (0 , uv_robot_origin [1 ]), (w , h ), color = base_color , thickness = - 1 )
104
+
105
+ # draw the virtual backbone
106
+ # add the first point of the proximal cap and the last point of the distal cap
107
+ chiv_ps = jnp .concatenate (
108
+ [
109
+ (chiv_ps [:, 0 ] - jnp .array ([0.0 , 0.0 , params ["lpc" ][0 ]])).reshape (3 , 1 ),
110
+ chiv_ps ,
111
+ (
112
+ chiv_ps [:, - 1 ]
113
+ + jnp .array (
114
+ [
115
+ chiv_ps [2 , - 1 ],
116
+ - jnp .sin (chiv_ps [2 , - 1 ]) * params ["ldc" ][- 1 ],
117
+ jnp .cos (chiv_ps [2 , - 1 ]) * params ["ldc" ][- 1 ],
118
+ ]
119
+ )
120
+ ).reshape (3 , 1 ),
121
+ ],
122
+ axis = 1 ,
123
+ )
124
+ curve_virtual_backbone = onp .array (batched_chi2u (chiv_ps ))
125
+ cv2 .polylines (
126
+ img , [curve_virtual_backbone ], isClosed = False , color = backbone_color , thickness = 5
127
+ )
128
+
129
+ # draw the rods
130
+ # add the first point of the proximal cap and the last point of the distal cap
131
+ chiL_ps = jnp .concatenate (
132
+ [
133
+ (chiL_ps [:, 0 ] - jnp .array ([0.0 , 0.0 , params ["lpc" ][0 ]])).reshape (3 , 1 ),
134
+ chiL_ps ,
135
+ (
136
+ chiL_ps [:, - 1 ]
137
+ + jnp .array (
138
+ [
139
+ chiL_ps [2 , - 1 ],
140
+ - jnp .sin (chiL_ps [2 , - 1 ]) * params ["ldc" ][- 1 ],
141
+ jnp .cos (chiL_ps [2 , - 1 ]) * params ["ldc" ][- 1 ],
142
+ ]
143
+ )
144
+ ).reshape (3 , 1 ),
145
+ ],
146
+ axis = 1 ,
147
+ )
148
+ curve_rod_left = onp .array (batched_chi2u (chiL_ps ))
149
+ cv2 .polylines (
150
+ img ,
151
+ [curve_rod_left ],
152
+ isClosed = False ,
153
+ color = rod_color ,
154
+ thickness = 10 ,
155
+ # thickness=2*int(ppm * params["rout"].mean(axis=0)[0])
156
+ )
157
+ # add the first point of the proximal cap and the last point of the distal cap
158
+ chiR_ps = jnp .concatenate (
159
+ [
160
+ (chiR_ps [:, 0 ] - jnp .array ([0.0 , 0.0 , params ["lpc" ][0 ]])).reshape (3 , 1 ),
161
+ chiR_ps ,
162
+ (
163
+ chiR_ps [:, - 1 ]
164
+ + jnp .array (
165
+ [
166
+ chiR_ps [2 , - 1 ],
167
+ - jnp .sin (chiR_ps [2 , - 1 ]) * params ["ldc" ][- 1 ],
168
+ jnp .cos (chiR_ps [2 , - 1 ]) * params ["ldc" ][- 1 ],
169
+ ]
170
+ )
171
+ ).reshape (3 , 1 ),
172
+ ],
173
+ axis = 1 ,
174
+ )
175
+ curve_rod_right = onp .array (batched_chi2u (chiR_ps ))
176
+ cv2 .polylines (img , [curve_rod_right ], isClosed = False , color = rod_color , thickness = 10 )
177
+
178
+ # draw the platform
179
+ for i in range (chip_ps .shape [0 ]):
180
+ # iterate over the platforms
181
+ platform_R = jnp .array (
182
+ [
183
+ [jnp .cos (chip_ps [i , 0 ]), - jnp .sin (chip_ps [i , 0 ])],
184
+ [jnp .sin (chip_ps [i , 0 ]), jnp .cos (chip_ps [i , 0 ])],
185
+ ]
186
+ ) # rotation matrix for the platform
187
+ platform_llc = chip_ps [i , 1 :] + platform_R @ jnp .array (
188
+ [
189
+ - params ["pcudim" ][i , 1 ] / 2 , # go half the width to the left
190
+ - params ["pcudim" ][i , 2 ] / 2 , # go half the height down
191
+ ]
192
+ ) # lower left corner of the platform
193
+ platform_ulc = chip_ps [i , 1 :] + platform_R @ jnp .array (
194
+ [
195
+ - params ["pcudim" ][i , 1 ] / 2 , # go half the width to the left
196
+ + params ["pcudim" ][i , 2 ] / 2 , # go half the height down
197
+ ]
198
+ ) # upper left corner of the platform
199
+ platform_urc = chip_ps [i , 1 :] + platform_R @ jnp .array (
200
+ [
201
+ + params ["pcudim" ][i , 1 ] / 2 , # go half the width to the left
202
+ + params ["pcudim" ][i , 2 ] / 2 , # go half the height down
203
+ ]
204
+ ) # upper right corner of the platform
205
+ platform_lrc = chip_ps [i , 1 :] + platform_R @ jnp .array (
206
+ [
207
+ + params ["pcudim" ][i , 1 ] / 2 , # go half the width to the left
208
+ - params ["pcudim" ][i , 2 ] / 2 , # go half the height down
209
+ ]
210
+ ) # lower right corner of the platform
211
+ platform_curve = jnp .stack (
212
+ [platform_llc , platform_ulc , platform_urc , platform_lrc , platform_llc ],
213
+ axis = 1 ,
214
+ )
215
+ # cv2.polylines(img, [onp.array(batched_chi2u(platform_curve))], isClosed=True, color=platform_color, thickness=5)
216
+ cv2 .fillPoly (
217
+ img , [onp .array (batched_chi2u (platform_curve ))], color = platform_color
218
+ )
219
+
220
+ return img
221
+
222
+
223
+ if __name__ == "__main__" :
224
+ num_segments = 1
225
+ num_rods_per_segment = 2
226
+
227
+ # filepath to symbolic expressions
228
+ sym_exp_filepath = (
229
+ Path (jsrm .__file__ ).parent
230
+ / "symbolic_expressions"
231
+ / f"planar_hsa_ns-{ num_segments } _nrs-{ num_rods_per_segment } .dill"
232
+ )
233
+
234
+ # activate all strains (i.e. bending, shear, and axial)
235
+ strain_selector = jnp .ones ((3 * num_segments ,), dtype = bool )
236
+ consider_hysteresis = True
237
+
238
+ params = PARAMS_FPU_HYSTERESIS_CONTROL if consider_hysteresis else PARAMS_FPU_CONTROL
239
+ # increase damping for simulation stability
240
+ params ["zetab" ] = 5 * params ["zetab" ]
241
+ params ["zetash" ] = 5 * params ["zetash" ]
242
+ params ["zetaa" ] = 5 * params ["zetaa" ]
243
+
244
+ # ======================================================
245
+ # Robot initialization
246
+ # ======================================================
247
+ robot = PlanarHSA (
248
+ sym_exp_filepath = sym_exp_filepath ,
249
+ params = params ,
250
+ strain_selector = strain_selector ,
251
+ consider_hysteresis = consider_hysteresis ,
252
+ )
253
+
254
+ # =====================================================
255
+ # Simulation upon time
256
+ # =====================================================
257
+ # Initial configuration
258
+ q0 = jnp .array ([jnp .pi , 0.0 , 0.0 ])
259
+ # Initial velocities
260
+ qd0 = jnp .zeros_like (q0 )
261
+ # Motor actuation angles
262
+ phi = jnp .array ([jnp .pi , jnp .pi / 2 ])
263
+
264
+ # Displaying the image
265
+ window_name = f"Planar HSA with { num_segments } segments"
266
+ img = draw_robot (
267
+ robot ,
268
+ q = q0 ,
269
+ )
270
+
271
+
272
+ # Simulation time parameters
273
+ t0 = 0.0
274
+ t1 = 5.0
275
+ dt = 5e-5 # time step
276
+ skip_step = 100 # how many time steps to skip in between video frames
277
+
278
+ # Solver
279
+ solver = Tsit5 ()
0 commit comments