Skip to content

Commit 2d0a009

Browse files
committed
Fix some bugs in the tendon actuated planar pcs implementation
1 parent 1b3d3a9 commit 2d0a009

File tree

6 files changed

+53
-54
lines changed

6 files changed

+53
-54
lines changed

examples/derive_planar_pcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import jsrm
44
from jsrm.symbolic_derivation.planar_pcs import symbolically_derive_planar_pcs_model
55

6-
NUM_SEGMENTS = 2
6+
NUM_SEGMENTS = 1
77

88
if __name__ == "__main__":
99
sym_exp_filepath = (

examples/simulate_tendon_actuated_planar_pcs.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
3737
"d": 2e-2 * jnp.array([[1.0, -1.0]]).repeat(num_segments, axis=0), # distance of tendons from the central axis [m]
3838
}
39-
print("params d =\n", params["d"])
4039
params["D"] = 1e-3 * jnp.diag(
4140
(jnp.repeat(
4241
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
@@ -53,7 +52,7 @@
5352

5453
# set simulation parameters
5554
dt = 1e-4 # time step
56-
ts = jnp.arange(0.0, 2, dt) # time steps
55+
ts = jnp.arange(0.0, 10.0, dt) # time steps
5756
skip_step = 10 # how many time steps to skip in between video frames
5857
video_ts = ts[::skip_step] # time steps for video
5958

@@ -100,34 +99,34 @@ def draw_robot(
10099

101100
if __name__ == "__main__":
102101
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
103-
planar_pcs.factory(sym_exp_filepath, strain_selector)
102+
planar_pcs.factory(num_segments, sym_exp_filepath, strain_selector)
104103
)
104+
actuation_mapping_fn = auxiliary_fns["actuation_mapping_fn"]
105105
# jit the functions
106106
dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn))
107107
batched_forward_kinematics = vmap(
108108
forward_kinematics_fn, in_axes=(None, None, 0), out_axes=-1
109109
)
110110

111-
# import matplotlib.pyplot as plt
112-
# plt.plot(chi_ps[0, :], chi_ps[1, :])
113-
# plt.axis("equal")
114-
# plt.grid(True)
115-
# plt.xlabel("x [m]")
116-
# plt.ylabel("y [m]")
117-
# plt.show()
118-
119-
# Displaying the image
120-
# window_name = f"Planar PCS with {num_segments} segments"
121-
# img = draw_robot(batched_forward_kinematics, params, q0, video_width, video_height)
122-
# cv2.namedWindow(window_name)
123-
# cv2.imshow(window_name, img)
124-
# cv2.waitKey()
125-
# cv2.destroyWindow(window_name)
111+
# test the actuation mapping function
112+
xi_eq = jnp.array([0.0, 0.0, 1.0])[None].repeat(num_segments, axis=0).flatten()
113+
B_xi = strain_basis
114+
# call the actuation mapping function
115+
A = actuation_mapping_fn(
116+
forward_kinematics_fn,
117+
auxiliary_fns["jacobian_fn"],
118+
params,
119+
B_xi,
120+
xi_eq,
121+
jnp.zeros_like(q0),
122+
)
123+
print("A =\n", A)
126124

127125
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
128-
tau = jnp.zeros_like(q0) # torques
126+
u = 1e0 * jnp.array([1.0, 1.0])[None].repeat(num_segments, axis=0).flatten() # tendon tensions
127+
print("u =\n", u)
129128

130-
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
129+
ode_fn = ode_factory(dynamical_matrices_fn, params, u)
131130
term = ODETerm(ode_fn)
132131

133132
sol = diffeqsolve(

src/jsrm/symbolic_derivation/planar_pcs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ def symbolically_derive_planar_pcs_model(
192192
"r": r_syms,
193193
"rho": rho_syms,
194194
"g": g_syms,
195-
"d": d,
196195
},
197196
"state_syms": {
198197
"xi": xi_syms,
-4 Bytes
Binary file not shown.
-85 Bytes
Binary file not shown.

src/jsrm/systems/tendon_actuated_planar_pcs.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as onp
66
from typing import Callable, Dict, Optional, Tuple, Union
77

8-
from .planar_pcs import factory as planar_pcs_factory, stifffness_fn
8+
from .planar_pcs import factory as planar_pcs_factory
99

1010
def factory(
1111
num_segments: int,
@@ -61,6 +61,8 @@ def actuation_mapping_fn(
6161
l = params["l"]
6262
# map the configuration to the strains
6363
xi = xi_eq + B_xi @ q
64+
# segment indices
65+
segment_indices = jnp.arange(num_segments)
6466

6567
def compute_actuation_matrix_for_segment(
6668
segment_idx, d_sm: Array,
@@ -70,61 +72,60 @@ def compute_actuation_matrix_for_segment(
7072
We assume that each segment is actuated by num_segment_tendons that are routed at a distance of d from the segment's backbone,
7173
respectively, and attached to the segment's distal end. We assume that the motor is located at the base of the robot and that the
7274
tendons are routed through all proximal segments.
73-
The control inputs u1 and u2 are the tensions (i.e., forces) applied by the two tendons.
75+
The positive control inputs u1 and u2 are the tensions (i.e., forces) applied by the two tendons.
76+
At a straight configuration with a positive d1, a positive u1 and zero u2 should cause the bend negatively (to the right) and contract its length.
7477
7578
Args:
7679
segment_idx: index of the segment
7780
d_sm: distance of the tendons from the segment's backbone (shape: (num_segment_tendons,))
7881
Returns:
7982
A_sm: actuation matrix of shape (n_xi, num_segment_tendons)
8083
"""
81-
num_segment_tendons = d_sm.shape[0]
82-
83-
A_sm = []
84-
for j in range(num_segment_tendons):
85-
d = d_sm[j]
86-
84+
def compute_A_d(d: Array) -> Array:
8785
"""
88-
A_d = []
89-
for i in range(0, segment_idx + 1):
90-
# length of the i-th segment
91-
l_i = l[i]
92-
# strain of the i-th segment
93-
xi_i = xi[3 * i:3 * (i + 1)]
94-
A_d.append(jnp.array([
95-
d * l_i * jnp.sqrt(xi_i[1]**2 + xi_i[2]**2), # the derivative of the tendon length with respect to the bending strain of the i-th segment
96-
l_i * xi_i[1] * (1 + d * xi_i[0]) / jnp.sqrt(xi_i[1]**2 + xi_i[2]**2), # the derivative of the tendon length with respect to the shear strain of the i-th segment
97-
l_i * xi_i[2] * (1 + d * xi_i[0]) / jnp.sqrt(xi_i[1]**2 + xi_i[2]**2), # the derivative of the tendon length with respect to the axial strain of the i-th segment
98-
]))
86+
Compute the actuation matrix for a single actuator/tendon with respect to the soft robot's strains.
87+
Args:
88+
d: distance of the tendon from the centerline
89+
Returns:
90+
A_d: actuation matrix of shape (n_xi, ) where n_xi is the number of strains
9991
"""
100-
def compute_A_d(l_i: Array, xi_i: Array) -> Array:
101-
print("l_i.shape", l_i.shape, "xi_i.shape", xi_i.shape)
92+
def compute_A_d_wrt_xi_i(i: Array, l_i: Array, xi_i: Array) -> Array:
93+
"""
94+
Compute the actuation matrix for a single actuator with respect to the strains of a single segment.
95+
Args:
96+
i: index of the segment
97+
l_i: length of the segment
98+
xi_i: strains for the segment
99+
Returns:
100+
A_d_segment: actuation matrix for the segment of shape (3, 3)
101+
"""
102102
sigma_norm = jnp.sqrt(xi_i[1] ** 2 + xi_i[2] ** 2)
103-
return jnp.array([
103+
A_d_wrt_xi_i = - jnp.array([
104104
d * l_i * sigma_norm,
105105
l_i * xi_i[1] * (1 + d * xi_i[0]) / sigma_norm,
106106
l_i * xi_i[2] * (1 + d * xi_i[0]) / sigma_norm,
107107
])
108-
A_d = vmap(compute_A_d)(l[:j+1], xi[: 3 * (j + 1)].reshape(-1, 3))
109-
108+
return jnp.where(
109+
i <= segment_idx,
110+
A_d_wrt_xi_i,
111+
jnp.zeros_like(A_d_wrt_xi_i)
112+
)
113+
114+
A_d = vmap(compute_A_d_wrt_xi_i)(segment_indices, l, xi.reshape(-1, 3))
115+
110116
# concatenate the derivatives for all segments
111117
A_d = jnp.concatenate(A_d, axis=0)
112-
A_sm.append(A_d)
113-
114-
# stack the actuation matrices for all tendons
115-
A_sm = jnp.stack(A_sm, axis=1)
116-
print("A_sm.shape", A_sm.shape)
118+
return A_d
119+
120+
A_sm = vmap(compute_A_d, in_axes=0, out_axes=1)(d_sm)
117121

118122
return A_sm
119123

120-
segment_indices = jnp.arange(num_segments)
121124
A_sms = vmap(compute_actuation_matrix_for_segment)(
122125
segment_indices, params["d"],
123126
)
124-
print("A_sms.shape", A_sms.shape)
125127
# concatenate the actuation matrices for all tendons
126128
A = jnp.concatenate(A_sms, axis=1)
127-
print("A.shape", A.shape)
128129

129130
# apply the actuation_basis
130131
A = A @ actuation_basis

0 commit comments

Comments
 (0)