Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ab00835
Add plot backup, full-JAX PCS planar + Gaussian quadrature integratio…
solangegbv May 12, 2025
231c9e3
Change fori_loop to vmap/scan and cond to min in
solangegbv May 13, 2025
19d91ea
Transformed simulate_planar_pcs into a function
solangegbv May 28, 2025
63ed97f
Replacement of at.set by block
solangegbv Jun 4, 2025
6dd50a4
Add quadax dependency, recover original files, remove jit for math_utils
solangegbv Jun 5, 2025
7c17c6a
Changing Coriolis for loop to vmap,
solangegbv Jun 6, 2025
3bb65cf
Corrected type error in planar_pcs_num.py
solangegbv Jun 10, 2025
7bb0d43
Added formulas in SE2, Coriolis corrections
solangegbv Jun 16, 2025
6952264
Correction of the kinetic energy function
solangegbv Jun 17, 2025
81e5895
Fix jit decorators to apply JIT-compilation to
solangegbv Jun 20, 2025
62a7aa4
Creation of a test file for planar_pcs_num.py
solangegbv Jul 7, 2025
d72094a
Roll-back changes to symbolic expressions
mstoelzle Jul 21, 2025
877634a
Bumpy version number and add Solange as an author
mstoelzle Jul 21, 2025
b4beba7
Merge branch 'main' into numerical-derivation
mstoelzle Jul 21, 2025
b8c1fe6
Rename `planar_pcs` system to `planar_pcs_sym`
mstoelzle Jul 21, 2025
5fec620
Fix some type hinting errors
mstoelzle Jul 21, 2025
6838cec
Fix missing changes in last commit
mstoelzle Jul 21, 2025
9baef51
Rename `test_planar_pcs.py` to `test_planar_pcs_sym`
mstoelzle Jul 21, 2025
3eb5391
Fix some bugs
mstoelzle Jul 21, 2025
94c4928
Format systems
mstoelzle Jul 21, 2025
44242a2
Format the `tests` files
mstoelzle Jul 21, 2025
e021a71
Format the `utils` files
mstoelzle Jul 21, 2025
7107cbd
Exclude some test scripts from automated testing if they require gui
mstoelzle Jul 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,5 @@ dmypy.json

# DS_STORE
.DS_Store
*.mp4
*.gif
1,594 changes: 1,594 additions & 0 deletions examples/benchmark_planar_pcs.py

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions examples/simulate_planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def draw_robot(
tau = jnp.zeros_like(q0) # torques

ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
# jit the ODE function
ode_fn = jax.jit(ode_fn)
term = ODETerm(ode_fn)

sol = diffeqsolve(
Expand All @@ -145,10 +147,14 @@ def draw_robot(
# the evolution of the generalized velocities
q_d_ts = sol.ys[:, n_q:]

s_max = jnp.array([jnp.sum(params["l"])])

forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, params, s=s_max)
forward_kinematics_fn_end_effector = jax.jit(forward_kinematics_fn_end_effector)
forward_kinematics_fn_end_effector = vmap(forward_kinematics_fn_end_effector)

# evaluate the forward kinematics along the trajectory
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
params, q_ts, jnp.array([jnp.sum(params["l"])])
)
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
# plot the configuration vs time
plt.figure()
for segment_idx in range(num_segments):
Expand Down Expand Up @@ -202,10 +208,10 @@ def draw_robot(

# plot the energy along the trajectory
kinetic_energy_fn_vmapped = vmap(
partial(auxiliary_fns["kinetic_energy_fn"], params)
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
)
potential_energy_fn_vmapped = vmap(
partial(auxiliary_fns["potential_energy_fn"], params)
partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params)
)
U_ts = potential_energy_fn_vmapped(q_ts)
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)
Expand Down
1 change: 0 additions & 1 deletion examples/videos/.gitignore

This file was deleted.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ name = "jsrm" # Required
#
# For a discussion on single-sourcing the version, see
# https://packaging.python.org/guides/single-sourcing-package-version/
version = "0.0.16" # Required
version = "0.0.17" # Required

# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
Expand Down Expand Up @@ -57,7 +57,8 @@ keywords = ["JAX", "Soft Robotics", "Kinematics", "Dynamics"] # Optional
# authored the project, and a valid email address corresponding to the name
# listed.
authors = [
{name = "Maximilian Stölzle", email = "[email protected]" } # Optional
{name = "Maximilian Stölzle", email = "[email protected]" },
{name = "Solange Gribonval", email = "[email protected]" },
]

# This should be your name or the names of the organization who currently
Expand Down Expand Up @@ -105,6 +106,7 @@ dependencies = [ # Optional
"dill",
"jax",
"numpy",
"quadax",
"peppercorn",
"sympy>=1.11"
]
Expand Down
4 changes: 1 addition & 3 deletions src/jsrm/integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jax import Array, jit
from jax import Array
from typing import Callable, Dict

from jsrm.systems import euler_lagrangian
Expand Down Expand Up @@ -26,7 +26,6 @@ def ode_factory(
ode_fn: ODE function of the form ode_fn(t, x) -> x_dot
"""

@jit
def ode_fn(t: float, x: Array, *args) -> Array:
"""
ODE of the dynamical Lagrangian system.
Expand Down Expand Up @@ -69,7 +68,6 @@ def ode_with_forcing_factory(
ode_fn: ODE function of the form ode_fn(t, x, tau) -> x_dot
"""

@jit
def ode_fn(
t: float,
x: Array,
Expand Down
42 changes: 36 additions & 6 deletions src/jsrm/math_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from jax import numpy as jnp
from jax import Array, lax, jit
from jax import Array, lax


@jit
def blk_diag(a: Array) -> Array:
def blk_diag(
a: Array
) -> Array:
"""
Create a block diagonal matrix from a tensor of blocks
Create a block diagonal matrix from a tensor of blocks.

Args:
a: matrices to be block diagonalized of shape (m, n, o)

Expand All @@ -31,7 +32,7 @@ def assign_block_diagonal(i, _b):

# Implement for loop to assign each block in `a` to the block-diagonal of `b`
# Hint: use `jax.lax.fori_loop` and pass `assign_block_diagonal` as an argument
b = jnp.zeros((a.shape[0] * a.shape[1], a.shape[0] * a.shape[2]))
b = jnp.zeros((a.shape[0] * a.shape[1], a.shape[0] * a.shape[2]), dtype=a.dtype)
b = lax.fori_loop(
lower=0,
upper=a.shape[0],
Expand All @@ -40,3 +41,32 @@ def assign_block_diagonal(i, _b):
)

return b

def blk_concat(
a: Array
) -> Array:
"""
Concatenate horizontally (along the columns) a list of N matrices of size (m, n) to create a single matrix of size (m, n * N).

Args:
a (Array): matrices to be concatenated of shape (N, m, n)

Returns:
b (Array): concatenated matrix of shape (m, N * n)
"""
b = a.transpose(1, 0, 2).reshape(a.shape[1], -1)
return b

if __name__ == "__main__":
# Example usage
a = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print("Original array:")
print(a)

b = blk_diag(a)
print("Block diagonal matrix:")
print(b)

c = blk_concat(a)
print("Concatenated matrix:")
print(c)
4 changes: 2 additions & 2 deletions src/jsrm/symbolic_derivation/planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def symbolically_derive_planar_pcs_model(
# tendon length jacobians for each segment as a function of the point coordinate s
J_tend_sms = []
# cross-sectional area of each segment
A = sp.zeros(num_segments)
A = sp.zeros(num_segments, 1)
# second area moment of inertia of each segment
I = sp.zeros(num_segments)
I = sp.zeros(num_segments, 1)
# inertia matrix
B = sp.zeros(num_dof, num_dof)
# potential energy
Expand Down
1 change: 1 addition & 0 deletions src/jsrm/systems/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import jsrm.systems.planar_pcs_sym as planar_pcs
Loading
Loading