17
17
18
18
import time
19
19
20
+ # ==================================================
21
+ # Simulation parameters
22
+ # ==================================================
20
23
num_segments = 2
21
24
22
25
type_of_derivation = "numeric" #"symbolic" #
26
+ type_of_integration = "trapezoid" #"gauss" #
27
+
28
+ if type_of_derivation == "numeric" :
29
+ if type_of_integration == "gauss" :
30
+ param_integration = 30
31
+ elif type_of_integration == "trapezoid" :
32
+ param_integration = 100
23
33
24
34
if type_of_derivation == "symbolic" :
25
35
# filepath to symbolic expressions
59
69
# ======================
60
70
# set simulation parameters
61
71
dt = 1e-4 # time step
62
- ts = jnp .arange (0.0 , 2 , dt ) # time steps
72
+ ts = jnp .arange (0.0 , 1 , dt ) # time steps
63
73
skip_step = 10 # how many time steps to skip in between video frames
64
74
video_ts = ts [::skip_step ] # time steps for video
65
75
66
76
# video settings
67
- video_width , video_height = 700 , 700 # img height and width
68
- video_path = Path (__file__ ).parent / "videos" / f"planar_pcs_ns-{ num_segments } -{ ('symb' if type_of_derivation == 'symbolic' else 'num' )} .mp4"
77
+ video_width , video_height = 700 , 700 # img height and width"
78
+ video_path_parent = Path (__file__ ).parent / "videos"
79
+ video_path_parent .mkdir (parents = True , exist_ok = True )
80
+ extension = f"planar_pcs_ns-{ num_segments } -{ ('symb' if type_of_derivation == 'symbolic' else 'num' )} "
81
+ if type_of_derivation == "numeric" :
82
+ extension += f"-{ type_of_integration } -{ param_integration } "
83
+ elif type_of_derivation == "symbolic" :
84
+ extension += "-symb"
85
+ video_path = video_path_parent / f"{ extension } .mp4"
69
86
70
87
def draw_robot (
71
88
batched_forward_kinematics_fn : Callable ,
@@ -107,28 +124,33 @@ def draw_robot(
107
124
# ======================
108
125
figures_path_parent = Path (__file__ ).parent / "figures"
109
126
extension = f"planar_pcs_ns-{ num_segments } -{ ('symb' if type_of_derivation == 'symbolic' else 'num' )} "
127
+ if type_of_derivation == "numeric" :
128
+ extension += f"-{ type_of_integration } -{ param_integration } "
110
129
figures_path_parent .mkdir (parents = True , exist_ok = True )
111
130
112
131
if __name__ == "__main__" :
113
132
print ("Type of derivation:" , type_of_derivation )
133
+ if type_of_derivation == "numeric" :
134
+ print ("Type of integration:" , type_of_integration )
135
+ print ("Parameter for integration:" , param_integration )
114
136
print ("Number of segments:" , num_segments , "\n " )
115
137
116
138
print ("Importing the planar PCS model..." )
117
139
timer_start = time .time ()
118
- # import jsrm
119
140
if type_of_derivation == "symbolic" :
120
141
strain_basis , forward_kinematics_fn , dynamical_matrices_fn , auxiliary_fns = (
121
142
planar_pcs .factory (sym_exp_filepath , strain_selector )
122
143
)
123
144
124
145
elif type_of_derivation == "numeric" :
125
146
strain_basis , forward_kinematics_fn , dynamical_matrices_fn , auxiliary_fns = (
126
- planar_pcs_num .factory (num_segments , strain_selector )
147
+ planar_pcs_num .factory (num_segments , strain_selector , integration_type = type_of_integration , param_integration = param_integration )
127
148
)
128
149
else :
129
150
raise ValueError ("type_of_derivation must be 'symbolic' or 'numeric'" )
130
151
131
152
# jit the functions
153
+ print ("JIT-compiling the dynamical matrices function ..." )
132
154
dynamical_matrices_fn = jax .jit (partial (dynamical_matrices_fn ))
133
155
batched_forward_kinematics = vmap (
134
156
forward_kinematics_fn , in_axes = (None , None , 0 ), out_axes = - 1
@@ -148,42 +170,51 @@ def draw_robot(
148
170
timer_end = time .time ()
149
171
print (f"Evaluating the dynamical matrices took { timer_end - timer_start :.2f} seconds. \n " )
150
172
173
+ # Parameter for the simulation
151
174
x0 = jnp .concatenate ([q0 , jnp .zeros_like (q0 )]) # initial condition
152
175
tau = jnp .zeros_like (q0 ) # torques
153
176
154
- timer_start = time .time ()
155
177
ode_fn = ode_factory (dynamical_matrices_fn , params , tau )
156
178
term = ODETerm (ode_fn )
157
- timer_end = time .time ()
158
179
159
- print ("Solving the ODE..." )
180
+ # jit the functions
181
+ print ("JIT-compiling the ODE function..." )
182
+ diffeqsolve_fn = jax .jit (
183
+ partial (diffeqsolve ,
184
+ term ,
185
+ solver = Tsit5 (),
186
+ t0 = ts [0 ],
187
+ t1 = ts [- 1 ],
188
+ dt0 = dt ,
189
+ y0 = x0 ,
190
+ max_steps = None ,
191
+ saveat = SaveAt (ts = video_ts )))
192
+
193
+ print ("Solving the ODE for the first time (JIT-compilation)..." )
160
194
timer_start = time .time ()
161
- sol = diffeqsolve (
162
- term ,
163
- solver = Tsit5 (),
164
- t0 = ts [0 ],
165
- t1 = ts [- 1 ],
166
- dt0 = dt ,
167
- y0 = x0 ,
168
- max_steps = None ,
169
- saveat = SaveAt (ts = video_ts ),
170
- )
171
-
195
+ sol = diffeqsolve_fn ()
172
196
print ("sol.ys =\n " , sol .ys )
197
+ timer_end = time .time ()
173
198
# the evolution of the generalized coordinates
174
199
q_ts = sol .ys [:, :n_q ]
175
200
# the evolution of the generalized velocities
176
201
q_d_ts = sol .ys [:, n_q :]
202
+ print (f"Solving the ODE took { timer_end - timer_start :.2f} seconds. \n " )
177
203
204
+ print ("Solving the ODE for the second time (after JIT-compilation)..." )
205
+ timer_start = time .time ()
206
+ sol = diffeqsolve_fn ()
207
+ print ("sol.ys =\n " , sol .ys )
178
208
timer_end = time .time ()
179
- print (f"Solving the ODE took { timer_end - timer_start :.2f} seconds. \n " )
209
+ print (f"Solving the ODE for a second time took { timer_end - timer_start :.2f} seconds. \n " )
180
210
181
211
print ("Evaluating the forward kinematics..." )
182
212
timer_start = time .time ()
183
213
# evaluate the forward kinematics along the trajectory
184
214
chi_ee_ts = vmap (forward_kinematics_fn , in_axes = (None , 0 , None ))(
185
215
params , q_ts , jnp .array ([jnp .sum (params ["l" ])])
186
216
)
217
+ print ("chi_ee_ts =\n " , chi_ee_ts )
187
218
timer_end = time .time ()
188
219
print (f"Evaluating the forward kinematics took { timer_end - timer_start :.2f} seconds. " )
189
220
@@ -264,7 +295,6 @@ def draw_robot(
264
295
# plot the energy vs time
265
296
plt .figure ()
266
297
plt .title ("Energy vs Time" )
267
- plt .plot (video_ts , U_ts + T_ts , label = "Total energy" )
268
298
plt .plot (video_ts , U_ts , label = "Potential energy" )
269
299
plt .plot (video_ts , T_ts , label = "Kinetic energy" )
270
300
plt .xlabel ("Time [s]" )
@@ -283,7 +313,6 @@ def draw_robot(
283
313
timer_start = time .time ()
284
314
# create video
285
315
fourcc = cv2 .VideoWriter_fourcc (* "mp4v" )
286
- video_path .parent .mkdir (parents = True , exist_ok = True )
287
316
video = cv2 .VideoWriter (
288
317
str (video_path ),
289
318
fourcc ,
0 commit comments