Skip to content

Commit b7944dc

Browse files
Add plot backup, full-JAX PCS planar + Gaussian quadrature integratio… (#6)
* Add plot backup, full-JAX PCS planar + Gaussian quadrature integration scheme * Change fori_loop to vmap/scan and cond to min in the planar_pcs_num.py file. Correct a description of function in the planar_pcs.py file. * Transformed simulate_planar_pcs into a function callable from other files for comparison purposes with the users choice parameters : - option of saving or not the results, figures, videos - option of plotting/printing or not the results, figures - option on the type of derivation to use : symbolic, numeric - option on the type of integration and parameter of integration to use : gauss, trapezoid - option on the type of jacobian to use : explicit or autodifferentiation Added the ability to save simulation results in pickle files (.pkl) for later comparison Set up an explicit Jacobian in SE3 to compute B and G - SE(3) Lie algebra operators - convert SE(2) to SE(3) to use operators for the planar case * Replacement of at.set by block Implementation of a Gauss-Kronrad quadrature integration using the quadax library Function documentation * Add quadax dependency, recover original files, remove jit for math_utils * Changing Coriolis for loop to vmap, get rid of jnp.array when possible, benchmark and tests on eps. * Corrected type error in planar_pcs_num.py * Added formulas in SE2, Coriolis corrections Autodiff: for loop calculation corrected Explicit: implementation of explicit calculation using Lie algebra Various tests * Correction of the kinetic energy function kinetic energy depends only on B and does not need to calculate other dynamic matrices Correction of documentations * Fix jit decorators to apply JIT-compilation to last level * Creation of a test file for planar_pcs_num.py Corrected documentation Removal of unnecessary imports Ready to merge * Roll-back changes to symbolic expressions * Bumpy version number and add Solange as an author * Rename `planar_pcs` system to `planar_pcs_sym` * Fix some type hinting errors * Fix missing changes in last commit * Rename `test_planar_pcs.py` to `test_planar_pcs_sym` * Fix some bugs * Format systems * Format the `tests` files * Format the `utils` files * Exclude some test scripts from automated testing if they require gui --------- Co-authored-by: Maximilian Stölzle <[email protected]>
1 parent 7b6590b commit b7944dc

21 files changed

+5281
-244
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,5 @@ dmypy.json
135135

136136
# DS_STORE
137137
.DS_Store
138+
*.mp4
139+
*.gif

examples/benchmark_planar_pcs.py

Lines changed: 1594 additions & 0 deletions
Large diffs are not rendered by default.

examples/simulate_planar_pcs.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def draw_robot(
126126
tau = jnp.zeros_like(q0) # torques
127127

128128
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
129+
# jit the ODE function
130+
ode_fn = jax.jit(ode_fn)
129131
term = ODETerm(ode_fn)
130132

131133
sol = diffeqsolve(
@@ -145,10 +147,14 @@ def draw_robot(
145147
# the evolution of the generalized velocities
146148
q_d_ts = sol.ys[:, n_q:]
147149

150+
s_max = jnp.array([jnp.sum(params["l"])])
151+
152+
forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, params, s=s_max)
153+
forward_kinematics_fn_end_effector = jax.jit(forward_kinematics_fn_end_effector)
154+
forward_kinematics_fn_end_effector = vmap(forward_kinematics_fn_end_effector)
155+
148156
# evaluate the forward kinematics along the trajectory
149-
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
150-
params, q_ts, jnp.array([jnp.sum(params["l"])])
151-
)
157+
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
152158
# plot the configuration vs time
153159
plt.figure()
154160
for segment_idx in range(num_segments):
@@ -202,10 +208,10 @@ def draw_robot(
202208

203209
# plot the energy along the trajectory
204210
kinetic_energy_fn_vmapped = vmap(
205-
partial(auxiliary_fns["kinetic_energy_fn"], params)
211+
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
206212
)
207213
potential_energy_fn_vmapped = vmap(
208-
partial(auxiliary_fns["potential_energy_fn"], params)
214+
partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params)
209215
)
210216
U_ts = potential_energy_fn_vmapped(q_ts)
211217
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)

examples/videos/.gitignore

Lines changed: 0 additions & 1 deletion
This file was deleted.

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ name = "jsrm" # Required
1717
#
1818
# For a discussion on single-sourcing the version, see
1919
# https://packaging.python.org/guides/single-sourcing-package-version/
20-
version = "0.0.16" # Required
20+
version = "0.0.17" # Required
2121

2222
# This is a one-line description or tagline of what your project does. This
2323
# corresponds to the "Summary" metadata field:
@@ -57,7 +57,8 @@ keywords = ["JAX", "Soft Robotics", "Kinematics", "Dynamics"] # Optional
5757
# authored the project, and a valid email address corresponding to the name
5858
# listed.
5959
authors = [
60-
{name = "Maximilian Stölzle", email = "[email protected]" } # Optional
60+
{name = "Maximilian Stölzle", email = "[email protected]" },
61+
{name = "Solange Gribonval", email = "[email protected]" },
6162
]
6263

6364
# This should be your name or the names of the organization who currently
@@ -105,6 +106,7 @@ dependencies = [ # Optional
105106
"dill",
106107
"jax",
107108
"numpy",
109+
"quadax",
108110
"peppercorn",
109111
"sympy>=1.11"
110112
]

src/jsrm/integration.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from jax import Array, jit
1+
from jax import Array
22
from typing import Callable, Dict
33

44
from jsrm.systems import euler_lagrangian
@@ -26,7 +26,6 @@ def ode_factory(
2626
ode_fn: ODE function of the form ode_fn(t, x) -> x_dot
2727
"""
2828

29-
@jit
3029
def ode_fn(t: float, x: Array, *args) -> Array:
3130
"""
3231
ODE of the dynamical Lagrangian system.
@@ -69,7 +68,6 @@ def ode_with_forcing_factory(
6968
ode_fn: ODE function of the form ode_fn(t, x, tau) -> x_dot
7069
"""
7170

72-
@jit
7371
def ode_fn(
7472
t: float,
7573
x: Array,

src/jsrm/math_utils.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from jax import numpy as jnp
2-
from jax import Array, lax, jit
2+
from jax import Array, lax
33

4-
5-
@jit
6-
def blk_diag(a: Array) -> Array:
4+
def blk_diag(
5+
a: Array
6+
) -> Array:
77
"""
8-
Create a block diagonal matrix from a tensor of blocks
8+
Create a block diagonal matrix from a tensor of blocks.
9+
910
Args:
1011
a: matrices to be block diagonalized of shape (m, n, o)
1112
@@ -31,7 +32,7 @@ def assign_block_diagonal(i, _b):
3132

3233
# Implement for loop to assign each block in `a` to the block-diagonal of `b`
3334
# Hint: use `jax.lax.fori_loop` and pass `assign_block_diagonal` as an argument
34-
b = jnp.zeros((a.shape[0] * a.shape[1], a.shape[0] * a.shape[2]))
35+
b = jnp.zeros((a.shape[0] * a.shape[1], a.shape[0] * a.shape[2]), dtype=a.dtype)
3536
b = lax.fori_loop(
3637
lower=0,
3738
upper=a.shape[0],
@@ -40,3 +41,32 @@ def assign_block_diagonal(i, _b):
4041
)
4142

4243
return b
44+
45+
def blk_concat(
46+
a: Array
47+
) -> Array:
48+
"""
49+
Concatenate horizontally (along the columns) a list of N matrices of size (m, n) to create a single matrix of size (m, n * N).
50+
51+
Args:
52+
a (Array): matrices to be concatenated of shape (N, m, n)
53+
54+
Returns:
55+
b (Array): concatenated matrix of shape (m, N * n)
56+
"""
57+
b = a.transpose(1, 0, 2).reshape(a.shape[1], -1)
58+
return b
59+
60+
if __name__ == "__main__":
61+
# Example usage
62+
a = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
63+
print("Original array:")
64+
print(a)
65+
66+
b = blk_diag(a)
67+
print("Block diagonal matrix:")
68+
print(b)
69+
70+
c = blk_concat(a)
71+
print("Concatenated matrix:")
72+
print(c)

src/jsrm/symbolic_derivation/planar_pcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def symbolically_derive_planar_pcs_model(
6565
# tendon length jacobians for each segment as a function of the point coordinate s
6666
J_tend_sms = []
6767
# cross-sectional area of each segment
68-
A = sp.zeros(num_segments)
68+
A = sp.zeros(num_segments, 1)
6969
# second area moment of inertia of each segment
70-
I = sp.zeros(num_segments)
70+
I = sp.zeros(num_segments, 1)
7171
# inertia matrix
7272
B = sp.zeros(num_dof, num_dof)
7373
# potential energy

src/jsrm/systems/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import jsrm.systems.planar_pcs_sym as planar_pcs

0 commit comments

Comments
 (0)