@@ -131,6 +131,8 @@ def draw_robot(
131
131
ode_fn = ode_factory (dynamical_matrices_fn , params , tau )
132
132
# jit the ODE function
133
133
ode_fn = jax .jit (ode_fn )
134
+ # jit the ODE function
135
+ ode_fn = jax .jit (ode_fn )
134
136
term = ODETerm (ode_fn )
135
137
136
138
sol = diffeqsolve (
@@ -150,6 +152,13 @@ def draw_robot(
150
152
# the evolution of the generalized velocities
151
153
q_d_ts = sol .ys [:, n_q :]
152
154
155
+ s_max = jnp .array ([jnp .sum (params ["l" ])])
156
+
157
+ forward_kinematics_fn_end_effector = partial (forward_kinematics_fn , params , s = s_max )
158
+ forward_kinematics_fn_end_effector = jax .jit (forward_kinematics_fn_end_effector )
159
+ forward_kinematics_fn_end_effector = vmap (forward_kinematics_fn_end_effector )
160
+
161
+
153
162
s_max = jnp .array ([jnp .sum (params ["l" ])])
154
163
155
164
forward_kinematics_fn_end_effector = partial (forward_kinematics_fn , params , s = s_max )
@@ -158,6 +167,7 @@ def draw_robot(
158
167
159
168
# evaluate the forward kinematics along the trajectory
160
169
chi_ee_ts = forward_kinematics_fn_end_effector (q_ts )
170
+ chi_ee_ts = forward_kinematics_fn_end_effector (q_ts )
161
171
# plot the configuration vs time
162
172
plt .figure ()
163
173
for segment_idx in range (num_segments ):
@@ -215,9 +225,11 @@ def draw_robot(
215
225
# plot the energy along the trajectory
216
226
kinetic_energy_fn_vmapped = vmap (
217
227
partial (jax .jit (auxiliary_fns ["kinetic_energy_fn" ]), params )
228
+ partial (jax .jit (auxiliary_fns ["kinetic_energy_fn" ]), params )
218
229
)
219
230
potential_energy_fn_vmapped = vmap (
220
231
partial (jax .jit (auxiliary_fns ["potential_energy_fn" ]), params )
232
+ partial (jax .jit (auxiliary_fns ["potential_energy_fn" ]), params )
221
233
)
222
234
U_ts = potential_energy_fn_vmapped (q_ts )
223
235
T_ts = kinetic_energy_fn_vmapped (q_ts , q_d_ts )
0 commit comments