1
+ import jax
2
+
3
+ from jsrm .systems .planar_pcs import PlanarPCSNum
4
+ import jax .numpy as jnp
5
+
6
+ from typing import Callable , Dict
7
+ from jax import Array
8
+
9
+ import numpy as onp
10
+
11
+ import matplotlib .pyplot as plt
12
+ from matplotlib .animation import FuncAnimation
13
+ from IPython .display import HTML
14
+
15
+ from diffrax import Tsit5
16
+
17
+ from functools import partial
18
+
19
+ jax .config .update ("jax_enable_x64" , True ) # double precision
20
+ jnp .set_printoptions (
21
+ threshold = jnp .inf ,
22
+ linewidth = jnp .inf ,
23
+ formatter = {'float_kind' : lambda x : '0' if x == 0 else f'{ x :.2e} ' }
24
+ )
25
+
26
+ def draw_robot_curve_class (
27
+ batched_forward_kinematics_fn : Callable ,
28
+ params : Dict [str , Array ],
29
+ q : Array ,
30
+ width : int ,
31
+ height : int ,
32
+ num_points : int = 50 ,
33
+ ):
34
+ h , w = height , width
35
+ ppm = h / (2.0 * jnp .sum (params ["l" ]))
36
+ s_ps = jnp .linspace (0 , jnp .sum (params ["l" ]), num_points )
37
+ chi_ps = batched_forward_kinematics_fn (q , s_ps )
38
+
39
+ # Position du robot dans les coordonnées pixel
40
+ curve_origin = onp .array ([w // 2 , 0.1 * h ])
41
+ curve = onp .array ((curve_origin [:, None ] + chi_ps [1 :, :] * ppm ), dtype = onp .float32 ).T
42
+ curve [:, 1 ] = h - curve [:, 1 ]
43
+
44
+ return curve # (N, 2)
45
+
46
+ def plot_robot_matplotlib (
47
+ batched_forward_kinematics_fn : Callable ,
48
+ params : Dict [str , Array ],
49
+ q : Array ,
50
+ width : int = 500 ,
51
+ height : int = 500 ,
52
+ num_points : int = 50 ,
53
+ show : bool = False ,
54
+ ):
55
+ fig , ax = plt .subplots ()
56
+ ax .set_xlim (0 , width )
57
+ ax .set_ylim (0 , height )
58
+ ax .invert_yaxis ()
59
+ line , = ax .plot ([], [], lw = 4 , color = "blue" )
60
+ curve = draw_robot_curve_class (batched_forward_kinematics_fn , params , q , width , height , num_points )
61
+ line .set_data (curve [:, 0 ], curve [:, 1 ])
62
+
63
+ if show :
64
+ plt .show (fig )
65
+
66
+ return fig
67
+
68
+ def animate_robot_matplotlib (
69
+ batched_forward_kinematics_fn : Callable ,
70
+ params : Dict [str , Array ],
71
+ t_list : Array , # shape (T,)
72
+ q_list : Array , # shape (T, DOF)
73
+ width : int = 500 ,
74
+ height : int = 500 ,
75
+ num_points : int = 50 ,
76
+ interval : int = 50 ,
77
+ boolshow : bool = True ,
78
+ ):
79
+ fig , ax = plt .subplots ()
80
+ ax .set_xlim (0 , width )
81
+ ax .set_ylim (0 , height )
82
+ ax .invert_yaxis ()
83
+ line , = ax .plot ([], [], lw = 4 , color = "blue" )
84
+ title_text = ax .set_title ("t = 0.00 s" )
85
+
86
+ def init ():
87
+ line .set_data ([], [])
88
+ title_text .set_text ("t = 0.00 s" )
89
+ return line , title_text
90
+
91
+ def update (frame_idx ):
92
+ q = q_list [frame_idx ]
93
+ t = t_list [frame_idx ]
94
+ curve = draw_robot_curve_class (batched_forward_kinematics_fn , params , q , width , height , num_points )
95
+ line .set_data (curve [:, 0 ], curve [:, 1 ])
96
+ title_text .set_text (f"t = { t :.2f} s" )
97
+ return line , title_text
98
+
99
+ ani = FuncAnimation (
100
+ fig ,
101
+ update ,
102
+ frames = len (q_list ),
103
+ init_func = init ,
104
+ blit = False ,
105
+ interval = interval )
106
+
107
+ if boolshow :
108
+ plt .show ()
109
+ plt .close (fig )
110
+ return HTML (ani .to_jshtml ())
111
+
112
+ if __name__ == "__main__" :
113
+ num_segments = 2
114
+ rho = 1070 * jnp .ones ((num_segments ,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
115
+ params = {
116
+ "th0" : jnp .array (0.0 ), # initial orientation angle [rad]
117
+ "l" : 1e-1 * jnp .ones ((num_segments ,)),
118
+ "r" : 2e-2 * jnp .ones ((num_segments ,)),
119
+ "rho" : rho ,
120
+ "g" : jnp .array ([0.0 , 9.81 ]),
121
+ "E" : 2e3 * jnp .ones ((num_segments ,)), # Elastic modulus [Pa]
122
+ "G" : 1e3 * jnp .ones ((num_segments ,)), # Shear modulus [Pa]
123
+ }
124
+ params ["D" ] = 1e-3 * jnp .diag (
125
+ (jnp .repeat (
126
+ jnp .array ([[1e0 , 1e3 , 1e3 ]]), num_segments , axis = 0
127
+ ) * params ["l" ][:, None ]).flatten ()
128
+ )
129
+
130
+ # ======================================================
131
+ # Robot initialization
132
+ # ======================================================
133
+ robot = PlanarPCSNum (
134
+ num_segments = num_segments ,
135
+ params = params ,
136
+ order_gauss = 5 ,
137
+ )
138
+
139
+ # =====================================================
140
+ # Simulation upon time
141
+ # =====================================================
142
+ # Initial configuration
143
+ q0 = jnp .repeat (jnp .array ([5.0 * jnp .pi , 0.1 , 0.2 ])[None , :], num_segments , axis = 0 ).flatten ()
144
+ # Initial velocities
145
+ qd0 = jnp .zeros_like (q0 )
146
+
147
+ # Actuation parameters
148
+ tau = jnp .zeros_like (q0 )
149
+ # WARNING: actuation_args need to be a tuple, even if it contains only one element
150
+ actuation_args = (tau ,)
151
+
152
+ # Simulation time parameters
153
+ t0 = 0.0
154
+ t1 = 2.0
155
+ dt = 1e-4
156
+ skip_step = 100 # how many time steps to skip in between video frames
157
+
158
+ # Solver
159
+ solver = Tsit5 () # Runge-Kutta 5(4) method
160
+
161
+ ts , q_ts , q_d_ts = robot .resolve_upon_time (
162
+ q0 = q0 ,
163
+ qd0 = qd0 ,
164
+ actuation_args = actuation_args ,
165
+ t0 = t0 ,
166
+ t1 = t1 ,
167
+ dt = dt ,
168
+ skip_steps = skip_step ,
169
+ max_steps = None
170
+ )
171
+
172
+ # =====================================================
173
+ # End-effector position upon time
174
+ # =====================================================
175
+ forward_kinematics_end_effector = jax .jit (partial (
176
+ robot .forward_kinematics_fn ,
177
+ s = jnp .sum (robot .l ) # end-effector position
178
+ ))
179
+ chi_ee_ts = jax .vmap (forward_kinematics_end_effector )(q_ts )
180
+
181
+ plt .figure ()
182
+ plt .plot (ts , chi_ee_ts [:, 1 ], label = "End-effector x [m]" )
183
+ plt .plot (ts , chi_ee_ts [:, 2 ], label = "End-effector y [m]" )
184
+ plt .xlabel ("Time [s]" )
185
+ plt .ylabel ("End-effector position [m]" )
186
+ plt .legend ()
187
+ plt .grid (True )
188
+ plt .box (True )
189
+ plt .tight_layout ()
190
+ plt .show ()
191
+
192
+ plt .figure ()
193
+ plt .scatter (chi_ee_ts [:, 1 ], chi_ee_ts [:, 2 ], c = ts , cmap = "viridis" )
194
+ plt .axis ("equal" )
195
+ plt .grid (True )
196
+ plt .xlabel ("End-effector x [m]" )
197
+ plt .ylabel ("End-effector y [m]" )
198
+ plt .colorbar (label = "Time [s]" )
199
+ plt .tight_layout ()
200
+ plt .show ()
201
+
202
+ # =====================================================
203
+ # Energy computation upon time
204
+ # =====================================================
205
+ U_ts = jax .vmap (jax .jit (partial (robot .potential_energy )))(q_ts )
206
+ T_ts = jax .vmap (jax .jit (partial (robot .kinetic_energy )))(q_ts , q_d_ts )
207
+
208
+ plt .figure ()
209
+ plt .plot (ts , U_ts , label = "Potential Energy" )
210
+ plt .plot (ts , T_ts , label = "Kinetic Energy" )
211
+ plt .xlabel ("Time (s)" )
212
+ plt .ylabel ("Energy (J)" )
213
+ plt .legend ()
214
+ plt .title ("Energy over Time" )
215
+ plt .grid (True )
216
+ plt .box (True )
217
+ plt .tight_layout ()
218
+ plt .show ()
219
+
220
+ # =====================================================
221
+ # Plot the robot configuration upon time
222
+ # =====================================================
223
+ animate_robot_matplotlib (
224
+ batched_forward_kinematics_fn = jax .vmap (robot .forward_kinematics_fn , in_axes = (None , 0 ), out_axes = - 1 ),
225
+ params = params ,
226
+ t_list = ts , # shape (T,)
227
+ q_list = q_ts , # shape (T, DOF)
228
+ width = 700 ,
229
+ height = 700 ,
230
+ num_points = 50 ,
231
+ interval = 100 , #ms
232
+ )
0 commit comments