Skip to content

Commit 101def6

Browse files
committed
Merge branch 'main' into gvs
2 parents 6cae55a + b7944dc commit 101def6

File tree

6 files changed

+2393
-1360
lines changed

6 files changed

+2393
-1360
lines changed

examples/simulate_planar_pcs_sym.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def draw_robot(
131131
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
132132
# jit the ODE function
133133
ode_fn = jax.jit(ode_fn)
134+
# jit the ODE function
135+
ode_fn = jax.jit(ode_fn)
134136
term = ODETerm(ode_fn)
135137

136138
sol = diffeqsolve(
@@ -158,6 +160,7 @@ def draw_robot(
158160

159161
# evaluate the forward kinematics along the trajectory
160162
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
163+
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
161164
# plot the configuration vs time
162165
plt.figure()
163166
for segment_idx in range(num_segments):
@@ -215,6 +218,7 @@ def draw_robot(
215218
# plot the energy along the trajectory
216219
kinetic_energy_fn_vmapped = vmap(
217220
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
221+
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
218222
)
219223
potential_energy_fn_vmapped = vmap(
220224
partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params)

src/jsrm/math_utils.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -69,37 +69,4 @@ def blk_concat(
6969

7070
c = blk_concat(a)
7171
print("Concatenated matrix:")
72-
print(c)
73-
74-
def compute_weighted_sums(M: Array, vecm: Array, idx: int) -> Array:
75-
"""
76-
Compute the weighted sums of the matrix product of M and vecm,
77-
78-
Args:
79-
M (Array): array of shape (N, m, m)
80-
Describes the matrix to be multiplied with vecm
81-
vecm (Array): array-like of shape (N, m)
82-
Describes the vector to be multiplied with M
83-
idx (int): index of the last row to be summed over
84-
85-
Returns:
86-
Array: array of shape (N, m)
87-
The result of the weighted sums. For each i, the result is the sum of the products of M[i, j] and vecm[j] for j from 0 to idx.
88-
"""
89-
N = M.shape[0]
90-
# Matrix product for each j: (N, m, m) @ (N, m, 1) -> (N, m)
91-
prod = jnp.einsum("nij,nj->ni", M, vecm)
92-
93-
# Triangular mask for partial sum: (N, N)
94-
# mask[i, j] = 1 if j >= i and j <= idx
95-
mask = (jnp.arange(N)[:, None] <= jnp.arange(N)[None, :]) & (
96-
jnp.arange(N)[None, :] <= idx
97-
)
98-
mask = mask.astype(M.dtype) # (N, N)
99-
100-
# Extend 6-dimensional mask (N, N, 1) to apply to (N, m)
101-
masked_prod = mask[:, :, None] * prod[None, :, :] # (N, N, m)
102-
103-
# Sum over j for each i : (N, m)
104-
result = masked_prod.sum(axis=1) # (N, m)
105-
return result
72+
print(c)

0 commit comments

Comments
 (0)