Skip to content

Commit 9104399

Browse files
committed
Add kinetic and potential energy functions for planar_pcs system
1 parent 6cd9e6a commit 9104399

File tree

7 files changed

+113
-7
lines changed

7 files changed

+113
-7
lines changed

examples/simulate_planar_pcs.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import cv2 # importing cv2
2+
from functools import partial
23
import jax
34

45
jax.config.update("jax_enable_x64", True) # double precision
56
from diffrax import diffeqsolve, Dopri5, Euler, ODETerm, SaveAt
67
from jax import Array, vmap
78
from jax import numpy as jnp
9+
import matplotlib.pyplot as plt
810
import numpy as onp
911
from pathlib import Path
1012
from typing import Callable, Dict
@@ -42,6 +44,8 @@
4244

4345
# define initial configuration
4446
q0 = jnp.array([10 * jnp.pi])
47+
# number of generalized coordinates
48+
n_q = q0.shape[0]
4549

4650
# set simulation parameters
4751
dt = 1e-3 # time step
@@ -91,7 +95,7 @@ def draw_robot(
9195

9296

9397
if __name__ == "__main__":
94-
strain_basis, forward_kinematics_fn, dynamical_matrices_fn = planar_pcs.factory(
98+
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = planar_pcs.factory(
9599
sym_exp_filepath, strain_selector
96100
)
97101
batched_forward_kinematics = vmap(
@@ -133,6 +137,25 @@ def draw_robot(
133137
)
134138

135139
print("sol.ys =\n", sol.ys)
140+
# the evolution of the generalized coordinates
141+
q_ts = sol.ys[:, :n_q]
142+
# the evolution of the generalized velocities
143+
q_d_ts = sol.ys[:, n_q:]
144+
145+
# plot the energy along the trajectory
146+
kinetic_energy_fn_vmapped = vmap(partial(auxiliary_fns["kinetic_energy_fn"], params))
147+
potential_energy_fn_vmapped = vmap(partial(auxiliary_fns["potential_energy_fn"], params))
148+
U_ts = potential_energy_fn_vmapped(q_ts)
149+
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)
150+
plt.figure()
151+
plt.plot(video_ts, U_ts, label="Potential energy")
152+
plt.plot(video_ts, T_ts, label="Kinetic energy")
153+
plt.xlabel("Time [s]")
154+
plt.ylabel("Energy [J]")
155+
plt.legend()
156+
plt.grid(True)
157+
plt.box(True)
158+
plt.show()
136159

137160
# create video
138161
fourcc = cv2.VideoWriter_fourcc(*"MP4V")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ name = "jsrm" # Required
1717
#
1818
# For a discussion on single-sourcing the version, see
1919
# https://packaging.python.org/guides/single-sourcing-package-version/
20-
version = "0.0.7" # Required
20+
version = "0.0.8" # Required
2121

2222
# This is a one-line description or tagline of what your project does. This
2323
# corresponds to the "Summary" metadata field:

src/jsrm/symbolic_derivation/planar_pcs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,10 @@ def symbolically_derive_planar_pcs_model(
169169
), # expression for end-effector pose of shape (3, )
170170
"J_sms": J_sms,
171171
"Jee": J_sms[-1].subs(s, l[-1]),
172-
"B": B,
173-
"C": C,
174-
"G": G,
172+
"B": B, # mass matrix
173+
"C": C, # coriolis matrix
174+
"G": G, # gravity vector
175+
"U": U, # gravitational potential energy
175176
},
176177
}
177178

1.64 KB
Binary file not shown.
26.9 KB
Binary file not shown.

src/jsrm/systems/planar_pcs.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def factory(
2525
[Dict[str, Array], Array, Array],
2626
Tuple[Array, Array, Array, Array, Array, Array],
2727
],
28+
Dict[str, Callable],
2829
]:
2930
"""
3031
Create jax functions from file containing symbolic expressions.
@@ -38,6 +39,7 @@ def factory(
3839
B_xi: strain basis matrix of shape (3 * num_segments, n_q)
3940
forward_kinematics_fn: function that returns the p vector of shape (3, n_q) with the positions
4041
dynamical_matrices_fn: function that returns the B, C, G, K, D, and alpha matrices
42+
auxiliary_fns: dictionary with auxiliary functions
4143
"""
4244
# load saved symbolic data
4345
sym_exps = dill.load(open(str(filepath), "rb"))
@@ -249,5 +251,85 @@ def dynamical_matrices_fn(
249251
alpha = B_xi.T @ jnp.identity(n_xi) @ B_xi
250252

251253
return B, C, G, K, D, alpha
254+
252255

253-
return B_xi, forward_kinematics_fn, dynamical_matrices_fn
256+
def kinetic_energy_fn(params: Dict[str, Array], q: Array, q_d: Array) -> Array:
257+
"""
258+
Compute the kinetic energy of the system.
259+
Args:
260+
params: Dictionary of robot parameters
261+
q: generalized coordinates of shape (n_q, )
262+
q_d: generalized velocities of shape (n_q, )
263+
Returns:
264+
T: kinetic energy of shape ()
265+
"""
266+
B, C, G, K, D, alpha = dynamical_matrices_fn(params, q=q, q_d=q_d)
267+
268+
# kinetic energy
269+
T = (0.5 * q_d.T @ B @ q_d).squeeze()
270+
271+
return T
272+
273+
274+
def potential_energy_fn(params: Dict[str, Array], q: Array, eps: float = 1e4 * global_eps) -> Array:
275+
"""
276+
Compute the potential energy of the system.
277+
Args:
278+
params: Dictionary of robot parameters
279+
q: generalized coordinates of shape (n_q, )
280+
eps: small number to avoid singularities (e.g., division by zero)
281+
Returns:
282+
U: potential energy of shape ()
283+
"""
284+
# map the configuration to the strains
285+
xi = xi_eq + B_xi @ q
286+
# add a small number to the bending strain to avoid singularities
287+
xi_epsed = apply_eps_to_bend_strains(xi, eps)
288+
289+
# cross-sectional area and second moment of area
290+
A = jnp.pi * params["r"] ** 2
291+
Ib = A**2 / (4 * jnp.pi)
292+
293+
# elastic and shear modulus
294+
E, G = params["E"], params["G"]
295+
# stiffness matrix of shape (num_segments, 3, 3)
296+
S = compute_stiffness_matrix_for_all_segments_fn(A, Ib, E, G)
297+
# we define the elastic matrix of shape (n_xi, n_xi) as K(xi) = K @ xi where K is equal to
298+
K = blk_diag(S)
299+
# elastic energy
300+
U_K = (xi - xi_eq).T @ K @ (xi - xi_eq) # evaluate K(xi) = K @ xi
301+
302+
# gravitational potential energy
303+
U_G = sp.Matrix([[0]])
304+
params_for_lambdify = select_params_for_lambdify(params)
305+
U_G = G_lambda(*params_for_lambdify, *xi_epsed).squeeze() @ xi_epsed
306+
307+
# total potential energy
308+
U = (U_G + U_K).squeeze()
309+
310+
return U
311+
312+
def energy_fn(params: Dict[str, Array], q: Array, q_d: Array) -> Array:
313+
"""
314+
Compute the total energy of the system.
315+
Args:
316+
params: Dictionary of robot parameters
317+
q: generalized coordinates of shape (n_q, )
318+
q_d: generalized velocities of shape (n_q, )
319+
Returns:
320+
E: total energy of shape ()
321+
"""
322+
T = kinetic_energy_fn(params, q_d)
323+
U = potential_energy_fn(params, q)
324+
E = T + U
325+
326+
return E
327+
328+
auxiliary_fns = {
329+
"apply_eps_to_bend_strains": apply_eps_to_bend_strains,
330+
"kinetic_energy_fn": kinetic_energy_fn,
331+
"potential_energy_fn": potential_energy_fn,
332+
"energy_fn": energy_fn,
333+
}
334+
335+
return B_xi, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns

tests/test_planar_pcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_planar_pcs_one_segment():
2727
# activate all strains (i.e. bending, shear, and axial)
2828
strain_selector = jnp.ones((3,), dtype=bool)
2929

30-
strain_basis, forward_kinematics_fn, dynamical_matrices_fn = planar_pcs.factory(
30+
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = planar_pcs.factory(
3131
sym_exp_filepath, strain_selector
3232
)
3333
forward_dynamics_fn = partial(

0 commit comments

Comments
 (0)