1
1
__all__ = ["factory" , "stiffness_fn" ]
2
- from jax import Array , lax , vmap
2
+ from jax import Array , debug , lax , vmap
3
3
import jax .numpy as jnp
4
4
from jsrm .math_utils import blk_diag
5
5
import numpy as onp
@@ -24,17 +24,10 @@ def factory(
24
24
if segment_actuation_selector is None :
25
25
segment_actuation_selector = jnp .ones (num_segments , dtype = bool )
26
26
27
- # number of input pressures
28
- actuation_dim = segment_actuation_selector .sum () * 2
29
-
30
- # matrix that maps the (possibly) underactuated actuation space to a full actuation space
31
- actuation_basis = jnp .zeros ((2 * num_segments , actuation_dim ))
32
- actuation_basis_cumsum = jnp .cumsum (segment_actuation_selector )
33
- for i in range (num_segments ):
34
- j = int (actuation_basis_cumsum [i ].item ()) - 1
35
- if segment_actuation_selector [i ].item () is True :
36
- actuation_basis = actuation_basis .at [2 * i , j ].set (1.0 )
37
- actuation_basis = actuation_basis .at [2 * i + 1 , j + 1 ].set (1.0 )
27
+ # number of system inputs
28
+ actuation_dim = segment_actuation_selector .sum ()
29
+ actuation_dim = 2 * num_segments
30
+ actuation_basis = jnp .eye (actuation_dim )
38
31
39
32
def actuation_mapping_fn (
40
33
forward_kinematics_fn : Callable ,
@@ -106,26 +99,22 @@ def compute_A_d_wrt_xi_i(i: Array, l_i: Array, xi_i: Array) -> Array:
106
99
l_i * xi_i [2 ] * (1 + d * xi_i [0 ]) / sigma_norm ,
107
100
])
108
101
return jnp .where (
109
- i <= segment_idx ,
102
+ i * jnp . ones (( 3 , )) <= segment_idx * jnp . ones (( 3 , )) ,
110
103
A_d_wrt_xi_i ,
111
104
jnp .zeros_like (A_d_wrt_xi_i )
112
105
)
113
106
114
- A_d = vmap (compute_A_d_wrt_xi_i )(segment_indices , l , xi .reshape (- 1 , 3 ))
107
+ A_d = vmap (compute_A_d_wrt_xi_i )(segment_indices , l , xi .reshape (- 1 , 3 )). reshape ( - 1 )
115
108
116
- # concatenate the derivatives for all segments
117
- A_d = jnp .concatenate (A_d , axis = 0 )
118
109
return A_d
119
110
120
- A_sm = vmap (compute_A_d , in_axes = 0 , out_axes = 1 )(d_sm )
111
+ A_sm = vmap (compute_A_d , in_axes = 0 , out_axes = - 1 )(d_sm )
121
112
122
113
return A_sm
123
114
124
- A_sms = vmap (compute_actuation_matrix_for_segment )(
115
+ A = vmap (compute_actuation_matrix_for_segment , in_axes = ( 0 , 0 ), out_axes = 0 )(
125
116
segment_indices , params ["d" ],
126
- )
127
- # concatenate the actuation matrices for all tendons
128
- A = jnp .concatenate (A_sms , axis = 1 )
117
+ ).transpose ((1 , 0 , 2 )).reshape (xi .shape [0 ], - 1 )
129
118
130
119
# apply the actuation_basis
131
120
A = A @ actuation_basis
0 commit comments