Skip to content

Commit a66de0b

Browse files
committed
Fix bugs for multi-segment case
1 parent 2d0a009 commit a66de0b

File tree

3 files changed

+13
-23
lines changed

3 files changed

+13
-23
lines changed

examples/simulate_tendon_actuated_planar_pcs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def draw_robot(
123123
print("A =\n", A)
124124

125125
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
126-
u = 1e0 * jnp.array([1.0, 1.0])[None].repeat(num_segments, axis=0).flatten() # tendon tensions
126+
u = jnp.array([1.0, 1.0])[None].repeat(num_segments, axis=0).flatten() # tendon tensions
127+
# u = 2e-1 * jnp.array([2.0, 0.0, 0.0, 1.0])
127128
print("u =\n", u)
128129

129130
ode_fn = ode_factory(dynamical_matrices_fn, params, u)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ name = "jsrm" # Required
1717
#
1818
# For a discussion on single-sourcing the version, see
1919
# https://packaging.python.org/guides/single-sourcing-package-version/
20-
version = "0.0.14" # Required
20+
version = "0.0.15" # Required
2121

2222
# This is a one-line description or tagline of what your project does. This
2323
# corresponds to the "Summary" metadata field:

src/jsrm/systems/tendon_actuated_planar_pcs.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
__all__ = ["factory", "stiffness_fn"]
2-
from jax import Array, lax, vmap
2+
from jax import Array, debug, lax, vmap
33
import jax.numpy as jnp
44
from jsrm.math_utils import blk_diag
55
import numpy as onp
@@ -24,17 +24,10 @@ def factory(
2424
if segment_actuation_selector is None:
2525
segment_actuation_selector = jnp.ones(num_segments, dtype=bool)
2626

27-
# number of input pressures
28-
actuation_dim = segment_actuation_selector.sum() * 2
29-
30-
# matrix that maps the (possibly) underactuated actuation space to a full actuation space
31-
actuation_basis = jnp.zeros((2 * num_segments, actuation_dim))
32-
actuation_basis_cumsum = jnp.cumsum(segment_actuation_selector)
33-
for i in range(num_segments):
34-
j = int(actuation_basis_cumsum[i].item()) - 1
35-
if segment_actuation_selector[i].item() is True:
36-
actuation_basis = actuation_basis.at[2 * i, j].set(1.0)
37-
actuation_basis = actuation_basis.at[2 * i + 1, j + 1].set(1.0)
27+
# number of system inputs
28+
actuation_dim = segment_actuation_selector.sum()
29+
actuation_dim = 2 * num_segments
30+
actuation_basis = jnp.eye(actuation_dim)
3831

3932
def actuation_mapping_fn(
4033
forward_kinematics_fn: Callable,
@@ -106,26 +99,22 @@ def compute_A_d_wrt_xi_i(i: Array, l_i: Array, xi_i: Array) -> Array:
10699
l_i * xi_i[2] * (1 + d * xi_i[0]) / sigma_norm,
107100
])
108101
return jnp.where(
109-
i <= segment_idx,
102+
i * jnp.ones((3, )) <= segment_idx * jnp.ones((3, )),
110103
A_d_wrt_xi_i,
111104
jnp.zeros_like(A_d_wrt_xi_i)
112105
)
113106

114-
A_d = vmap(compute_A_d_wrt_xi_i)(segment_indices, l, xi.reshape(-1, 3))
107+
A_d = vmap(compute_A_d_wrt_xi_i)(segment_indices, l, xi.reshape(-1, 3)).reshape(-1)
115108

116-
# concatenate the derivatives for all segments
117-
A_d = jnp.concatenate(A_d, axis=0)
118109
return A_d
119110

120-
A_sm = vmap(compute_A_d, in_axes=0, out_axes=1)(d_sm)
111+
A_sm = vmap(compute_A_d, in_axes=0, out_axes=-1)(d_sm)
121112

122113
return A_sm
123114

124-
A_sms = vmap(compute_actuation_matrix_for_segment)(
115+
A = vmap(compute_actuation_matrix_for_segment, in_axes=(0, 0), out_axes=0)(
125116
segment_indices, params["d"],
126-
)
127-
# concatenate the actuation matrices for all tendons
128-
A = jnp.concatenate(A_sms, axis=1)
117+
).transpose((1, 0, 2)).reshape(xi.shape[0], -1)
129118

130119
# apply the actuation_basis
131120
A = A @ actuation_basis

0 commit comments

Comments
 (0)