Skip to content

Commit 303c370

Browse files
committed
Make operational_space_dynamical_matrices_fn jittable again by specifying operational_space_selector as a tuple
1 parent 3067a41 commit 303c370

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/jsrm/systems/planar_pcs.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dill
22
from jax import Array, jit, lax, vmap
33
from jax import numpy as jnp
4+
import numpy as onp
45
import sympy as sp
56
from pathlib import Path
67
from typing import Callable, Dict, List, Tuple, Union
@@ -404,7 +405,7 @@ def operational_space_dynamical_matrices_fn(
404405
s: Array,
405406
B: Array,
406407
C: Array,
407-
operational_space_selector = jnp.array([True, True, True]),
408+
operational_space_selector: Tuple = (True, True, True),
408409
eps: float = 1e4 * global_eps,
409410
) -> Tuple[Array, Array, Array, Array, Array]:
410411
"""
@@ -418,8 +419,8 @@ def operational_space_dynamical_matrices_fn(
418419
s: point coordinate along the robot in the interval [0, L].
419420
B: inertia matrix in the generalized coordinates of shape (n_q, n_q)
420421
C: coriolis matrix derived with Christoffer symbols in the generalized coordinates of shape (n_q, n_q)
421-
operational_space_selector: boolean array of shape (3,) to select the operational space variables.
422-
For examples, jnp.array([True, True, False]) selects only the positional components of the operational space.
422+
operational_space_selector: tuple of shape (3,) to select the operational space variables.
423+
For examples, (True, True, False) selects only the positional components of the operational space.
423424
eps: small number to avoid singularities (e.g., division by zero)
424425
Returns:
425426
Lambda: inertia matrix in the operational space of shape (3, 3)
@@ -443,6 +444,9 @@ def operational_space_dynamical_matrices_fn(
443444
# convert the dictionary of parameters to a list, which we can pass to the lambda function
444445
params_for_lambdify = select_params_for_lambdify_fn(params)
445446

447+
# make operational_space_selector a boolean array
448+
operational_space_selector = onp.array(operational_space_selector, dtype=bool)
449+
446450
# Jacobian and its time-derivative
447451
J = lax.switch(
448452
segment_idx, J_lambda_sms, *params_for_lambdify, *xi_epsed, s_segment

0 commit comments

Comments
 (0)