Skip to content

Commit 81e5895

Browse files
committed
Fix jit decorators to apply JIT-compilation to
last level
1 parent 6952264 commit 81e5895

14 files changed

+474
-2150
lines changed

examples/benchmark_planar_pcs_num.py renamed to examples/benchmark_planar_pcs.py

Lines changed: 440 additions & 1044 deletions
Large diffs are not rendered by default.

examples/simulate_planar_pcs.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def draw_robot(
126126
tau = jnp.zeros_like(q0) # torques
127127

128128
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
129+
# jit the ODE function
130+
ode_fn = jax.jit(ode_fn)
129131
term = ODETerm(ode_fn)
130132

131133
sol = diffeqsolve(
@@ -145,10 +147,14 @@ def draw_robot(
145147
# the evolution of the generalized velocities
146148
q_d_ts = sol.ys[:, n_q:]
147149

150+
s_max = jnp.array([jnp.sum(params["l"])])
151+
152+
forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, params, s=s_max)
153+
forward_kinematics_fn_end_effector = jax.jit(forward_kinematics_fn_end_effector)
154+
forward_kinematics_fn_end_effector = vmap(forward_kinematics_fn_end_effector)
155+
148156
# evaluate the forward kinematics along the trajectory
149-
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
150-
params, q_ts, jnp.array([jnp.sum(params["l"])])
151-
)
157+
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
152158
# plot the configuration vs time
153159
plt.figure()
154160
for segment_idx in range(num_segments):
@@ -202,10 +208,10 @@ def draw_robot(
202208

203209
# plot the energy along the trajectory
204210
kinetic_energy_fn_vmapped = vmap(
205-
partial(auxiliary_fns["kinetic_energy_fn"], params)
211+
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
206212
)
207213
potential_energy_fn_vmapped = vmap(
208-
partial(auxiliary_fns["potential_energy_fn"], params)
214+
partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params)
209215
)
210216
U_ts = potential_energy_fn_vmapped(q_ts)
211217
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)

src/jsrm/integration.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from jax import Array, jit
1+
from jax import Array
22
from typing import Callable, Dict
33

44
from jsrm.systems import euler_lagrangian
@@ -26,7 +26,6 @@ def ode_factory(
2626
ode_fn: ODE function of the form ode_fn(t, x) -> x_dot
2727
"""
2828

29-
@jit
3029
def ode_fn(t: float, x: Array, *args) -> Array:
3130
"""
3231
ODE of the dynamical Lagrangian system.
@@ -69,7 +68,6 @@ def ode_with_forcing_factory(
6968
ode_fn: ODE function of the form ode_fn(t, x, tau) -> x_dot
7069
"""
7170

72-
@jit
7371
def ode_fn(
7472
t: float,
7573
x: Array,

src/jsrm/math_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from jax import numpy as jnp
2-
from jax import Array, lax, jit
2+
from jax import Array, lax
33

44
def blk_diag(
55
a: Array

src/jsrm/systems/planar_pcs.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
)
1414
from jsrm.math_utils import blk_diag
1515

16-
1716
def factory(
1817
filepath: Union[str, Path],
1918
strain_selector: Array = None,
@@ -54,7 +53,6 @@ def factory(
5453
# symbols for robot parameters
5554
params_syms = sym_exps["params_syms"]
5655

57-
@jit
5856
def select_params_for_lambdify_fn(params: Dict[str, Array]) -> List[Array]:
5957
"""
6058
Select the parameters for lambdify
@@ -147,7 +145,6 @@ def select_params_for_lambdify_fn(params: Dict[str, Array]) -> List[Array]:
147145
compute_planar_stiffness_matrix
148146
)
149147

150-
@jit
151148
def apply_eps_to_bend_strains(xi: Array, _eps: float) -> Array:
152149
"""
153150
Add a small number to the bending strain to avoid singularities
@@ -263,7 +260,6 @@ def actuation_mapping_fn(
263260

264261
return A
265262

266-
@jit
267263
def forward_kinematics_fn(
268264
params: Dict[str, Array], q: Array, s: Array, eps: float = global_eps
269265
) -> Array:
@@ -297,7 +293,6 @@ def forward_kinematics_fn(
297293

298294
return chi
299295

300-
@jit
301296
def jacobian_fn(
302297
params: Dict[str, Array], q: Array, s: Array, eps: float = global_eps
303298
) -> Array:
@@ -333,7 +328,6 @@ def jacobian_fn(
333328

334329
return J
335330

336-
@jit
337331
def dynamical_matrices_fn(
338332
params: Dict[str, Array], q: Array, q_d: Array, eps: float = 1e4 * global_eps
339333
) -> Tuple[Array, Array, Array, Array, Array, Array]:

0 commit comments

Comments
 (0)