Skip to content

Commit 839bf50

Browse files
committed
Fully implement pneumatic actuation model
1 parent 8d4a7b0 commit 839bf50

File tree

3 files changed

+113
-29
lines changed

3 files changed

+113
-29
lines changed

examples/simulate_pneumatic_planar_pcs.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
"r_cham_out": 2e-2 - 2e-3 * jnp.ones((num_segments,)),
3838
"varphi_cham": jnp.pi/2 * jnp.ones((num_segments,)),
3939
}
40-
params["D"] = 1e-3 * jnp.diag(
40+
params["D"] = 5e-4 * jnp.diag(
4141
(jnp.repeat(
4242
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
4343
) * params["l"][:, None]).flatten()
@@ -52,29 +52,88 @@
5252
)
5353
# jit the functions
5454
dynamical_matrices_fn = jax.jit(dynamical_matrices_fn)
55-
actuation_mapping_fn = auxiliary_fns["actuation_mapping_fn"]
55+
actuation_mapping_fn = partial(
56+
auxiliary_fns["actuation_mapping_fn"],
57+
forward_kinematics_fn,
58+
auxiliary_fns["jacobian_fn"],
59+
)
5660

5761
def sweep_actuation_mapping():
62+
# evaluate the actuation matrix for a straight backbone
5863
q = jnp.zeros((2 * num_segments,))
5964
A = actuation_mapping_fn(params, B_xi, q)
60-
print("A =\n", A)
65+
print("Evaluating actuation matrix for straight backbone: A =\n", A)
66+
67+
kappa_be_pts = jnp.linspace(-jnp.pi, jnp.pi, 500)
68+
sigma_ax_pts = jnp.zeros_like(kappa_be_pts)
69+
q_pts = jnp.stack([kappa_be_pts, sigma_ax_pts], axis=-1)
70+
A_pts = vmap(actuation_mapping_fn, in_axes=(None, None, 0))(params, B_xi, q_pts)
71+
# plot the mapping on the bending strain for various bending strains
72+
fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_bending_torque_vs_bending_strain")
73+
plt.title(r"Actuation mapping from $u_1$ to $\tau_\mathrm{be}$")
74+
ax.plot(kappa_be_pts, A_pts[:, 0, 0], linewidth=2)
75+
# shade the region where the actuation mapping is negative as we are not able to bend the robot further
76+
ax.axhspan(A_pts[:, 0, 0].min(), 0.0, facecolor='red', alpha=0.5)
77+
ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]")
78+
ax.set_ylabel(r"$\frac{\partial \tau_\mathrm{be}}{\partial u_1}$")
79+
plt.grid(True)
80+
plt.tight_layout()
81+
plt.show()
82+
83+
# create grid for bending and axial strains
84+
kappa_be_grid, sigma_ax_grid = jnp.meshgrid(
85+
jnp.linspace(-jnp.pi, jnp.pi, 20),
86+
jnp.linspace(-0.2, 0.2, 20),
87+
)
88+
q_pts = jnp.stack([kappa_be_grid.flatten(), sigma_ax_grid.flatten()], axis=-1)
89+
90+
# evaluate the actuation mapping on the grid
91+
A_pts = vmap(actuation_mapping_fn, in_axes=(None, None, 0))(params, B_xi, q_pts)
92+
# reshape A_pts to match the grid shape
93+
A_grid = A_pts.reshape(kappa_be_grid.shape[:2] + A_pts.shape[-2:])
94+
95+
# plot the mapping on the bending strain
96+
fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_bending_torque_vs_axial_vs_bending_strain")
97+
plt.title(r"Actuation mapping from $u_1$ to $\tau_\mathrm{be}$")
98+
# contourf plot
99+
c = ax.contourf(kappa_be_grid, sigma_ax_grid, A_grid[..., 0, 0], levels=100)
100+
fig.colorbar(c, ax=ax, label=r"$\frac{\partial \tau_\mathrm{be}}{\partial u_1}$")
101+
# contour plot
102+
ax.contour(kappa_be_grid, sigma_ax_grid, A_grid[..., 0, 0], levels=20, colors="k", linewidths=0.5)
103+
ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]")
104+
ax.set_ylabel(r"$\sigma_\mathrm{ax}$ [-]")
105+
plt.tight_layout()
106+
plt.show()
107+
108+
# plot the mapping on the axial strain
109+
fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_axial_torque_vs_axial_vs_bending_strain")
110+
plt.title(r"Actuation mapping from $u_1$ to $\tau_\mathrm{ax}$")
111+
# contourf plot
112+
c = ax.contourf(kappa_be_grid, sigma_ax_grid, A_grid[..., 1, 0], levels=100)
113+
fig.colorbar(c, ax=ax, label=r"$\frac{\partial \tau_\mathrm{ax}}{\partial u_1}$")
114+
# contour plot
115+
ax.contour(kappa_be_grid, sigma_ax_grid, A_grid[..., 1, 0], levels=20, colors="k", linewidths=0.5)
116+
ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]")
117+
ax.set_ylabel(r"$\sigma_\mathrm{ax}$ [-]")
118+
plt.tight_layout()
119+
plt.show()
61120

62121

63122
def simulate_robot():
64123
# define initial configuration
65-
q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.2])[None, :], num_segments, axis=0).flatten()
124+
q0 = jnp.repeat(jnp.array([-5.0 * jnp.pi, -0.2])[None, :], num_segments, axis=0).flatten()
66125
# number of generalized coordinates
67126
n_q = q0.shape[0]
68127

69128
# set simulation parameters
70129
dt = 1e-3 # time step
71130
sim_dt = 5e-5 # simulation time step
72-
ts = jnp.arange(0.0, 2, dt) # time steps
131+
ts = jnp.arange(0.0, 7.0, dt) # time steps
73132

74133
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
75-
tau = jnp.zeros_like(q0) # torques
134+
u = jnp.array([1.2e3, 0e0]) # control inputs (pressures in the right and left chambers)
76135

77-
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
136+
ode_fn = ode_factory(dynamical_matrices_fn, params, u)
78137
term = ODETerm(ode_fn)
79138

80139
sol = diffeqsolve(

src/jsrm/systems/planar_pcs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def actuation_mapping_fn(
256256
Returns:
257257
A: actuation matrix of shape (n_xi, n_xi) where n_xi is the number of strains.
258258
"""
259-
A = jnp.identity(n_xi) @ B_xi
259+
A = B_xi.T @ jnp.identity(n_xi) @ B_xi
260260

261261
return A
262262

@@ -359,7 +359,7 @@ def dynamical_matrices_fn(
359359
# compute the stiffness matrix
360360
K = stiffness_fn(params, B_xi, formulate_in_strain_space=True)
361361
# compute the actuation matrix
362-
A = actuation_mapping_fn(forward_kinematics_fn, actuation_mapping_fn, params, B_xi, q)
362+
A = actuation_mapping_fn(forward_kinematics_fn, jacobian_fn, params, B_xi, q)
363363

364364
# dissipative matrix from the parameters
365365
D = params.get("D", jnp.zeros((n_xi, n_xi)))
@@ -376,7 +376,7 @@ def dynamical_matrices_fn(
376376
D = B_xi.T @ D @ B_xi
377377

378378
# apply the strain basis to the actuation matrix
379-
alpha = B_xi.T @ A
379+
alpha = A
380380

381381
return B, C, G, K, D, alpha
382382

src/jsrm/systems/pneumatic_planar_pcs.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,8 @@ def actuation_mapping_fn(
5555
A: actuation matrix of shape (n_xi, n_act) where n_xi is the number of strains and
5656
n_act is the number of actuators
5757
"""
58-
# map the configurations to strains
59-
xi = B_xi @ q
60-
61-
# number of strains
62-
n_xi = xi.shape[0]
63-
6458
# all segment bases and tips
6559
sms = jnp.concat([jnp.zeros((1,)), jnp.cumsum(params["l"])], axis=0)
66-
print("sms =\n", sms)
6760

6861
# compute the poses of all segment tips
6962
chi_sms = vmap(forward_kinematics_fn, in_axes=(None, None, 0))(params, q, sms)
@@ -74,10 +67,20 @@ def actuation_mapping_fn(
7467
def compute_actuation_matrix_for_segment(
7568
r_cham_in: Array, r_cham_out: Array, varphi_cham: Array,
7669
chi_pe: Array, chi_de: Array,
77-
J_pe: Array, J_de: Array, xi: Array
70+
J_pe: Array, J_de: Array,
7871
) -> Array:
7972
"""
8073
Compute the actuation matrix for a single segment.
74+
We assume that each segment contains four identical and symmetric pneumatic chambers with pressures
75+
p1, p2, p3, and p4, where p1 and p3 are the right and left chamber pressures respectively, and
76+
p2 and p4 are the back and front chamber pressures respectively. The front and back chambers
77+
do not exert a level arm (i.e., a bending moment) on the segment.
78+
We map the control inputs u1 and u2 as follows to the pressures:
79+
p1 = u1 (right chamber)
80+
p2 = (u1 + u2) / 2
81+
p3 = u2 (left chamber)
82+
p4 = (u1 + u2) / 2
83+
8184
Args:
8285
r_cham_in: inner radius of each segment chamber
8386
r_cham_out: outer radius of each segment chamber
@@ -86,23 +89,45 @@ def compute_actuation_matrix_for_segment(
8689
chi_de: pose of the distal end (i.e., the tip) of the segment as array of shape (3,)
8790
J_pe: Jacobian of the proximal end of the segment as array of shape (3, n_q)
8891
J_de: Jacobian of the distal end of the segment as array of shape (3, n_q)
89-
xi: strains of the segment
9092
Returns:
9193
A_sm: actuation matrix of shape (n_xi, 2)
9294
"""
93-
# rotation matrix from the robot base to the segment base
94-
R_pe = jnp.array([[jnp.cos(chi_pe[2]), -jnp.sin(chi_pe[2])], [jnp.sin(chi_pe[2]), jnp.cos(chi_pe[2])]])
95-
# rotation matrix from the robot base to the segment tip
96-
R_de = jnp.array([[jnp.cos(chi_de[2]), -jnp.sin(chi_de[2])], [jnp.sin(chi_de[2]), jnp.cos(chi_de[2])]])
95+
# orientation of the proximal and distal ends of the segment
96+
th_pe, th_de = chi_pe[2], chi_de[2]
97+
98+
# compute the area of each pneumatic chamber (we assume identical chambers within a segment)
99+
A_cham = 0.5 * varphi_cham * (r_cham_out ** 2 - r_cham_in ** 2)
100+
# compute the center of pressure of the pneumatic chamber
101+
r_cop = (
102+
2 / 3 * jnp.sinc(0.5 * varphi_cham) * (r_cham_out ** 3 - r_cham_in ** 3) / (r_cham_out ** 2 - r_cham_in ** 2)
103+
)
104+
105+
# compute the actuation matrix that collects the contributions of the pneumatic chambers in the given segment
106+
# first we consider the contribution of the distal end
107+
A_sm_de = J_de.T @ jnp.array([
108+
[-2 * A_cham * jnp.sin(th_de), -2 * A_cham * jnp.sin(th_de)],
109+
[2 * A_cham * jnp.cos(th_de), 2 * A_cham * jnp.cos(th_de)],
110+
[A_cham * r_cop, -A_cham * r_cop]
111+
])
112+
# then, we consider the contribution of the proximal end
113+
A_sm_pe = J_pe.T @ jnp.array([
114+
[2 * A_cham * jnp.sin(th_pe), 2 * A_cham * jnp.sin(th_pe)],
115+
[-2 * A_cham * jnp.cos(th_pe), -2 * A_cham * jnp.cos(th_pe)],
116+
[-A_cham * r_cop, A_cham * r_cop]
117+
])
118+
119+
# sum the contributions of the distal and proximal ends
120+
A_sm = A_sm_de + A_sm_pe
97121

98-
99-
# compute the actuation matrix for a single segment
100-
A_sm = jnp.zeros((n_xi, 2))
101122
return A_sm
102123

103-
A_sms = vmap(compute_actuation_matrix_for_segment)(chi_sms, J_sms, xi)
104-
105-
A = jnp.zeros((n_xi, 2 * num_segments))
124+
A_sms = vmap(compute_actuation_matrix_for_segment)(
125+
params["r_cham_in"], params["r_cham_out"], params["varphi_cham"],
126+
chi_pe=chi_sms[:-1], chi_de=chi_sms[1:],
127+
J_pe=J_sms[:-1], J_de=J_sms[1:],
128+
)
129+
# we need to sum the contributions of the actuation of each segment
130+
A = jnp.sum(A_sms, axis=0)
106131

107132
# apply the actuation_basis
108133
A = A @ actuation_basis

0 commit comments

Comments
 (0)