Skip to content

Commit 6dd50a4

Browse files
committed
Add quadax dependency, recover original files, remove jit for math_utils
1 parent 63ed97f commit 6dd50a4

File tree

8 files changed

+1790
-839
lines changed

8 files changed

+1790
-839
lines changed

examples/benchmark_planar_pcs_num.py

Lines changed: 1415 additions & 0 deletions
Large diffs are not rendered by default.

examples/simulate_planar_pcs.py

Lines changed: 214 additions & 776 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ dependencies = [ # Optional
105105
"dill",
106106
"jax",
107107
"numpy",
108+
"quadax",
108109
"peppercorn",
109110
"sympy>=1.11"
110111
]

src/jsrm/math_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from jax import numpy as jnp
22
from jax import Array, lax, jit
33

4-
@jit
54
def blk_diag(
65
a: Array
76
) -> Array:
@@ -43,18 +42,17 @@ def assign_block_diagonal(i, _b):
4342

4443
return b
4544

46-
@jit
4745
def blk_concat(
4846
a: Array
4947
) -> Array:
5048
"""
51-
Concatenate horizontally (along the columns) a list of N matrices of size (a, b) to create a single matrix of size (a, b * N).
49+
Concatenate horizontally (along the columns) a list of N matrices of size (m, n) to create a single matrix of size (m, n * N).
5250
5351
Args:
54-
a (Array): matrices to be concatenated of shape (N, a, b)
52+
a (Array): matrices to be concatenated of shape (N, m, n)
5553
5654
Returns:
57-
Array: concatenated matrix of shape (a, N * b)
55+
b (Array): concatenated matrix of shape (m, N * n)
5856
"""
5957
b = a.transpose(1, 0, 2).reshape(a.shape[1], -1)
6058
return b
224 Bytes
Binary file not shown.
545 Bytes
Binary file not shown.

src/jsrm/systems/planar_pcs_num.py

Lines changed: 140 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def classify_segment(
281281
return segment_idx, s_segment.squeeze(), l_cum
282282

283283
@jit
284-
def chi_fn_xi(
284+
def chi_fn(
285285
params: Dict[str, Array],
286286
xi: Array,
287287
s: Array
@@ -297,13 +297,13 @@ def chi_fn_xi(
297297
Returns:
298298
Array: pose of the robot at the point s in the interval [0, L].
299299
"""
300-
th0 = jnp.array(params["th0"]) # initial angle of the robot
300+
th0 = params["th0"] # initial angle of the robot
301301
l = params["l"] # length of each segment [m]
302302

303303
# Classify the point along the robot to the corresponding segment
304304
segment_idx, s_local, _ = classify_segment(params, s)
305305

306-
chi_O = jnp.array([0.0, 0.0, th0]) # Initial pose of the robot
306+
chi_0 = jnp.array([0, 0, th0]) # Initial pose of the robot #TODO
307307

308308
# Iteration function
309309
def chi_i(
@@ -340,7 +340,7 @@ def chi_i(
340340

341341
chi, chi_list = lax.scan(
342342
f = lambda carry, i: (chi_i(i, carry), chi_i(i, carry)),
343-
init = chi_O,
343+
init = chi_0,
344344
xs = jnp.arange(num_segments + 1))
345345

346346
return chi_list[segment_idx]
@@ -372,7 +372,7 @@ def forward_kinematics_fn(
372372
# Add a small number to the bending strain to avoid singularities
373373
xi = apply_eps_to_bend_strains(xi, eps)
374374

375-
chi = chi_fn_xi(params, xi, s)
375+
chi = chi_fn(params, xi, s)
376376

377377
return chi
378378

@@ -399,7 +399,7 @@ def J_autodiff(
399399
xi = apply_eps_to_bend_strains(xi, eps)
400400

401401
# Compute the Jacobian of chi_fn with respect to xi
402-
J = jacobian(lambda xi: chi_fn_xi(params, xi, s))(xi)
402+
J = jacobian(lambda _xi: chi_fn(params, _xi, s))(xi)
403403

404404
# apply the strain basis to the Jacobian
405405
J = J @ B_xi
@@ -507,7 +507,7 @@ def J_i(
507507
# From local to global frame : applying the rotation of the pose at point s
508508

509509
# Get the pose at point s
510-
px, py, theta = chi_fn_xi(params, xi, s)
510+
px, py, theta = chi_fn(params, xi, s)
511511
# Convert the pose to SE(3) representation
512512
R = jnp.array([ # Rotation matrix around the z-axis
513513
[jnp.cos(theta), -jnp.sin(theta)],
@@ -560,7 +560,7 @@ def J_Jd_autodiff(
560560
xi = apply_eps_to_bend_strains(xi, eps)
561561

562562
# Compute the Jacobian of chi_fn with respect to xi
563-
J = jacobian(lambda xi: chi_fn_xi(params, xi, s))(xi)
563+
J = jacobian(lambda _xi: chi_fn(params, _xi, s))(xi)
564564

565565
dJ_dxi = jacobian(J)(xi)
566566
J_d = jnp.tensordot(dJ_dxi, xi_d, axes=([2], [0]))
@@ -683,7 +683,7 @@ def J_i(
683683
# From local to global frame : applying the rotation of the pose at point s
684684

685685
# Get the pose at point s
686-
px, py, theta = chi_fn_xi(params, xi, s)
686+
px, py, theta = chi_fn(params, xi, s)
687687
# Convert the pose to SE(3) representation
688688
g_s = jnp.eye(4) # Initialize as identity matrix
689689
R = jnp.array([ # Rotation matrix around the z-axis
@@ -704,13 +704,13 @@ def J_i(
704704
lambda i: lax.dynamic_slice(xi_d, (3 * i,), (3,))
705705
)(idx_range) # shape: (num_segments, 3)
706706
xi_d_SE3_i = vmap(
707-
lambda xi_d_i : vec_SE2_to_xi_SE3(xi_d_i, SE2_to_SE3_indices)
707+
lambda _xi_d_i: vec_SE2_to_xi_SE3(_xi_d_i, SE2_to_SE3_indices)
708708
)(xi_d_i) # shape: (num_segments, 6)
709709
S_i = vmap(
710-
lambda i:lax.dynamic_index_in_dim(J_segment_SE3_global, i, axis=0, keepdims=False)
710+
lambda i: lax.dynamic_index_in_dim(J_segment_SE3_global, i, axis=0, keepdims=False)
711711
)(idx_range) # shape: (num_segments, 6, 6)
712712
sum_Sj_xi_d_j = vmap(
713-
lambda i, xi_d_SE3_i: compute_weighted_sums(J_segment_SE3_global, xi_d_SE3_i, i)
713+
lambda i, _xi_d_SE3_i: compute_weighted_sums(J_segment_SE3_global, _xi_d_SE3_i, i)
714714
)(idx_range, xi_d_SE3_i) # shape: (num_segments, 6)
715715
adjoint_sum = vmap(adjoint_SE3)(sum_Sj_xi_d_j) # shape: (num_segments, 6, 6)
716716

@@ -728,7 +728,7 @@ def J_i(
728728
# From local to global frame : applying the rotation of the pose at point s
729729

730730
# Get the pose at point s
731-
px, py, theta = chi_fn_xi(params, xi, s)
731+
px, py, theta = chi_fn(params, xi, s)
732732
# Convert the pose to SE(3) representation
733733
g_s = jnp.eye(4) # Initialize as identity matrix
734734
R = jnp.array([ # Rotation matrix around the z-axis
@@ -874,7 +874,7 @@ def integrand(s):
874874
return B
875875

876876
@jit
877-
def B_C_fn_xi(
877+
def B_C_fn(
878878
params: Dict[str, Array],
879879
xi: Array,
880880
xi_d: Array
@@ -888,31 +888,125 @@ def B_C_fn_xi(
888888
xi_d (Array): velocity vector of the robot.
889889
890890
Returns:
891-
Tuple[Array, Array]:
892-
- B (Array): mass / inertia matrix of the robot.
893-
- C (Array): Coriolis / centrifugal matrix of the robot.
891+
B (Array): mass / inertia matrix of the robot.
892+
C (Array): Coriolis / centrifugal matrix of the robot.
894893
"""
895894

896895
# Compute the mass / inertia matrix
897896
B = B_fn_xi(params, xi)
898897

899-
# Compute the Christoffel symbols
900-
def christoffel_symbol(i, j, k):
898+
# # Compute the Christoffel symbols
899+
# def christoffel_symbol(i, j, k):
900+
# return 0.5 * (
901+
# grad(lambda x: B[i, j])(xi)[k]
902+
# + grad(lambda x: B[i, k])(xi)[j]
903+
# - grad(lambda x: B[j, k])(xi)[i]
904+
# )
905+
906+
# # Compute the Coriolis / centrifugal matrix
907+
# C = jnp.zeros_like(B)
908+
# for i in range(B.shape[0]):
909+
# for j in range(B.shape[1]):
910+
# for k in range(B.shape[0]):
911+
# C = C.at[i, j].add(christoffel_symbol(i, j, k) * xi_d[k])
912+
913+
n = B.shape[0] # number of segments
914+
915+
# Compute Christoffel symbols for all (i,j,k)
916+
def christoffel_fn(i, j, k):
901917
return 0.5 * (
902-
grad(lambda x: B[i, j])(xi)[k]
903-
+ grad(lambda x: B[i, k])(xi)[j]
904-
- grad(lambda x: B[j, k])(xi)[i]
918+
grad(lambda x: B_fn_xi(params, x)[i, j])(xi)[k]
919+
+ grad(lambda x: B_fn_xi(params, x)[i, k])(xi)[j]
920+
- grad(lambda x: B_fn_xi(params, x)[j, k])(xi)[i]
905921
)
922+
923+
# Vectorize over k
924+
def C_ij(i, j):
925+
cs_k = vmap(lambda k: christoffel_fn(i, j, k))(jnp.arange(n)) # shape (n,)
926+
return jnp.dot(cs_k, xi_d)
906927

907-
# Compute the Coriolis / centrifugal matrix
908-
C = jnp.zeros_like(B)
909-
for i in range(B.shape[0]):
910-
for j in range(B.shape[1]):
911-
for k in range(B.shape[0]):
912-
C = C.at[i, j].add(christoffel_symbol(i, j, k) * xi_d[k])
913-
928+
# Vectorize over i and j
929+
C = vmap(lambda i: vmap(lambda j: C_ij(i, j))(jnp.arange(n)))(jnp.arange(n)) # shape (n, n)
930+
914931
return B, C
915932

933+
# @jit
934+
# def B_C_fn_explicit(
935+
# params: Dict[str, Array],
936+
# xi: Array
937+
# ) -> Array:
938+
# """
939+
# Compute the mass / inertia matrix of the robot.
940+
941+
# Args:
942+
# params (Dict[str, Array]): dictionary of robot parameters.
943+
# xi (Array): strain vector of the robot.
944+
945+
# Returns:
946+
# Array: mass / inertia matrix of the robot.
947+
# """
948+
# # Extract the parameters
949+
# rho = params["rho"] # density of each segment [kg/m^3]
950+
# l = params["l"] # length of each segment [m]
951+
# r = params["r"] # radius of each segment [m]
952+
953+
# # Usefull derived quantities
954+
# A = jnp.pi * r**2 # cross-sectional area of each segment [m^2]
955+
# Ib = A**2 / (4 * jnp.pi) # second moment of area of each segment [m^4]
956+
957+
# l_cum = jnp.cumsum(jnp.concatenate([jnp.array([0.0]), l]))
958+
# # Compute each integral
959+
# def compute_integral(i):
960+
# if integration_type == "gauss-legendre":
961+
# Xs, Ws, nGauss = gauss_quadrature(N_GQ=param_integration, a=l_cum[i], b=l_cum[i + 1])
962+
963+
# J_all = vmap(lambda s: jacobian_fn_xi(params, xi, s))(Xs)
964+
# Jp_all = J_all[:, :2, :]
965+
# Jo_all = J_all[:, 2:, :]
966+
967+
# integrand_JpT_Jp = jnp.einsum("nij,nik->njk", Jp_all, Jp_all)
968+
# integrand_JoT_Jo = jnp.einsum("nij,nik->njk", Jo_all, Jo_all)
969+
970+
# integral_Jp = jnp.sum(Ws[:, None, None] * integrand_JpT_Jp, axis=0)
971+
# integral_Jo = jnp.sum(Ws[:, None, None] * integrand_JoT_Jo, axis=0)
972+
973+
# integral_B = rho[i] * A[i] * integral_Jp + rho[i] * Ib[i] * integral_Jo
974+
975+
# elif integration_type == "gauss-kronrad":
976+
# rule = GaussKronrodRule(order=param_integration)
977+
# def integrand(s):
978+
# J = jacobian_fn_xi(params, xi, s)
979+
# Jp = J[:2, :]
980+
# Jo = J[2:, :]
981+
# return rho[i] * A[i] * Jp.T @ Jp + rho[i] * Ib[i] * Jo.T @ Jo
982+
983+
# integral_B, _, _, _ = rule.integrate(integrand, l_cum[i], l_cum[i+1], args=())
984+
985+
# elif integration_type == "trapezoid":
986+
# xs = jnp.linspace(l_cum[i], l_cum[i + 1], param_integration)
987+
988+
# J_all = vmap(lambda s: jacobian_fn_xi(params, xi, s))(xs)
989+
# Jp_all = J_all[:, :2, :]
990+
# Jo_all = J_all[:, 2:, :]
991+
992+
# integrand_JpT_Jp = jnp.einsum("nij,nik->njk", Jp_all, Jp_all)
993+
# integrand_JoT_Jo = jnp.einsum("nij,nik->njk", Jo_all, Jo_all)
994+
995+
# integral_Jp = jscipy.integrate.trapezoid(integrand_JpT_Jp, x=xs, axis=0)
996+
# integral_Jo = jscipy.integrate.trapezoid(integrand_JoT_Jo, x=xs, axis=0)
997+
998+
# integral_B = rho[i] * A[i] * integral_Jp + rho[i] * Ib[i] * integral_Jo
999+
1000+
# return integral_B
1001+
1002+
# # Compute the cumulative integral
1003+
# indices = jnp.arange(num_segments)
1004+
# integrals = vmap(compute_integral)(indices)
1005+
1006+
# B = jnp.sum(integrals, axis=0)
1007+
1008+
# return B
1009+
9161010
@jit
9171011
def U_g_fn_xi(
9181012
params: Dict[str, Array],
@@ -928,7 +1022,7 @@ def U_g_fn_xi(
9281022
eps (float, optional): small number to avoid singularities. Defaults to global_eps.
9291023
9301024
Returns:
931-
Array: gravity vector of the robot.
1025+
U_g (Array): gravity vector of the robot.
9321026
"""
9331027
# Add a small number to the bending strain to avoid singularities
9341028
xi = apply_eps_to_bend_strains(xi, eps)
@@ -947,7 +1041,7 @@ def U_g_fn_xi(
9471041
def compute_integral(i):
9481042
if integration_type == "gauss-legendre":
9491043
Xs, Ws, nGauss = gauss_quadrature(N_GQ=param_integration, a=l_cum[i], b=l_cum[i + 1])
950-
chi_s = vmap(lambda s: chi_fn_xi(params, xi, s))(Xs)
1044+
chi_s = vmap(lambda s: chi_fn(params, xi, s))(Xs)
9511045
p_s = chi_s[:, :2]
9521046
integrand = -rho[i] * A[i] * jnp.einsum("ij,j->i", p_s, g)
9531047

@@ -957,7 +1051,7 @@ def compute_integral(i):
9571051
elif integration_type == "gauss-kronrad":
9581052
rule = GaussKronrodRule(order=param_integration)
9591053
def integrand(s):
960-
chi_s = chi_fn_xi(params, xi, s)
1054+
chi_s = chi_fn(params, xi, s)
9611055
p_s = chi_s[:2]
9621056
return -rho[i] * A[i] * jnp.dot(p_s, g)
9631057

@@ -966,7 +1060,7 @@ def integrand(s):
9661060

9671061
elif integration_type == "trapezoid":
9681062
xs = jnp.linspace(l_cum[i], l_cum[i + 1], param_integration)
969-
chi_s = vmap(lambda s: chi_fn_xi(params, xi, s))(xs)
1063+
chi_s = vmap(lambda s: chi_fn(params, xi, s))(xs)
9701064
p_s = chi_s[:, :2]
9711065
integrand = -rho[i] * A[i] * jnp.einsum("ij,j->i", p_s, g)
9721066

@@ -984,7 +1078,7 @@ def integrand(s):
9841078
return U_g
9851079

9861080
@jit
987-
def G_fn_xi_autodiff(
1081+
def G_autodiff_fn(
9881082
params: Dict[str, Array],
9891083
xi: Array
9901084
) -> Array:
@@ -996,14 +1090,14 @@ def G_fn_xi_autodiff(
9961090
xi (Array): strain vector of the robot.
9971091
9981092
Returns:
999-
Array: gravity vector of the robot.
1093+
G (Array) : gravity vector of the robot.
10001094
"""
10011095

1002-
G = jacobian(lambda xi: U_g_fn_xi(params, xi))(xi)
1096+
G = jacobian(lambda _xi: U_g_fn_xi(params, _xi))(xi)
10031097
return G
10041098

10051099
@jit
1006-
def G_fn_xi_explicit(
1100+
def G_explicit_fn(
10071101
params: Dict[str, Array],
10081102
xi: Array,
10091103
eps: float = global_eps
@@ -1085,9 +1179,9 @@ def integrand(s):
10851179
return G
10861180

10871181
if jacobian_type == "explicit":
1088-
G_fn_xi = G_fn_xi_explicit
1182+
G_fn_xi = G_explicit_fn
10891183
elif jacobian_type == "autodiff":
1090-
G_fn_xi = G_fn_xi_autodiff
1184+
G_fn_xi = G_autodiff_fn
10911185

10921186
@jit
10931187
def dynamical_matrices_fn(
@@ -1106,13 +1200,12 @@ def dynamical_matrices_fn(
11061200
eps (float, optional): small number to avoid singularities. Defaults to 1e4 * global_eps.
11071201
11081202
Returns:
1109-
Tuple:
1110-
- B (Array): mass / inertia matrix of the robot. (shape: (n_q, n_q))
1111-
- C (Array): Coriolis / centrifugal matrix of the robot. (shape: (n_q, n_q))
1112-
- G (Array): gravity vector of the robot. (shape: (n_q,))
1113-
- K (Array): elastic vector of the robot. (shape: (n_q,))
1114-
- D (Array): dissipative matrix of the robot. (shape: (n_q, n_q))
1115-
- alpha (Array): actuation matrix of the robot. (shape: (n_q, n_tau))
1203+
B (Array): mass / inertia matrix of the robot. (shape: (n_q, n_q))
1204+
C (Array): Coriolis / centrifugal matrix of the robot. (shape: (n_q, n_q))
1205+
G (Array): gravity vector of the robot. (shape: (n_q,))
1206+
K (Array): elastic vector of the robot. (shape: (n_q,))
1207+
D (Array): dissipative matrix of the robot. (shape: (n_q, n_q))
1208+
alpha (Array): actuation matrix of the robot. (shape: (n_q, n_tau))
11161209
"""
11171210
# Map the configuration to the strains
11181211
xi = xi_eq + B_xi @ q
@@ -1138,7 +1231,7 @@ def dynamical_matrices_fn(
11381231
# Apply the strain basis to the dissipative matrix
11391232
D = B_xi.T @ D @ B_xi
11401233

1141-
B, C = B_C_fn_xi(params, xi, xi_d)
1234+
B, C = B_C_fn(params, xi, xi_d)
11421235

11431236
G = B_xi.T @ G_fn_xi(params, xi).squeeze()
11441237

0 commit comments

Comments
 (0)