@@ -24,11 +24,6 @@ def factory(
24
24
if segment_actuation_selector is None :
25
25
segment_actuation_selector = jnp .ones (num_segments , dtype = bool )
26
26
27
- # number of system inputs
28
- actuation_dim = segment_actuation_selector .sum ()
29
- actuation_dim = 2 * num_segments
30
- actuation_basis = jnp .eye (actuation_dim )
31
-
32
27
def actuation_mapping_fn (
33
28
forward_kinematics_fn : Callable ,
34
29
jacobian_fn : Callable ,
@@ -112,12 +107,17 @@ def compute_A_d_wrt_xi_i(i: Array, l_i: Array, xi_i: Array) -> Array:
112
107
113
108
return A_sm
114
109
110
+ # compute the actuation matrix for all segments
111
+ # will have shape (num_segments, n_xi, num_segment_tendons)
115
112
A = vmap (compute_actuation_matrix_for_segment , in_axes = (0 , 0 ), out_axes = 0 )(
116
113
segment_indices , params ["d" ],
117
- ).transpose ((1 , 0 , 2 )).reshape (xi .shape [0 ], - 1 )
114
+ )
115
+
116
+ # deactivate the actuation for some segments
117
+ A = A [segment_actuation_selector ]
118
118
119
- # apply the actuation_basis
120
- A = A @ actuation_basis
119
+ # reshape the actuation matrix to have shape (n_xi, n_act)
120
+ A = A . transpose (( 1 , 0 , 2 )). reshape ( xi . shape [ 0 ], - 1 )
121
121
122
122
return A
123
123
0 commit comments