Skip to content

Commit 6a455ac

Browse files
committed
GVS compatible with main branch
1 parent 101def6 commit 6a455ac

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

examples/simulate_planar_pcs_sym.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,6 @@ def draw_robot(
131131
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
132132
# jit the ODE function
133133
ode_fn = jax.jit(ode_fn)
134-
# jit the ODE function
135-
ode_fn = jax.jit(ode_fn)
136134
term = ODETerm(ode_fn)
137135

138136
sol = diffeqsolve(
@@ -160,7 +158,6 @@ def draw_robot(
160158

161159
# evaluate the forward kinematics along the trajectory
162160
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
163-
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
164161
# plot the configuration vs time
165162
plt.figure()
166163
for segment_idx in range(num_segments):
@@ -218,7 +215,6 @@ def draw_robot(
218215
# plot the energy along the trajectory
219216
kinetic_energy_fn_vmapped = vmap(
220217
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
221-
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
222218
)
223219
potential_energy_fn_vmapped = vmap(
224220
partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params)

src/jsrm/math_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,40 @@ def blk_concat(
5757
b = a.transpose(1, 0, 2).reshape(a.shape[1], -1)
5858
return b
5959

60+
61+
def compute_weighted_sums(M: Array, vecm: Array, idx: int) -> Array:
62+
"""
63+
Compute the weighted sums of the matrix product of M and vecm,
64+
65+
Args:
66+
M (Array): array of shape (N, m, m)
67+
Describes the matrix to be multiplied with vecm
68+
vecm (Array): array-like of shape (N, m)
69+
Describes the vector to be multiplied with M
70+
idx (int): index of the last row to be summed over
71+
72+
Returns:
73+
Array: array of shape (N, m)
74+
The result of the weighted sums. For each i, the result is the sum of the products of M[i, j] and vecm[j] for j from 0 to idx.
75+
"""
76+
N = M.shape[0]
77+
# Matrix product for each j: (N, m, m) @ (N, m, 1) -> (N, m)
78+
prod = jnp.einsum("nij,nj->ni", M, vecm)
79+
80+
# Triangular mask for partial sum: (N, N)
81+
# mask[i, j] = 1 if j >= i and j <= idx
82+
mask = (jnp.arange(N)[:, None] <= jnp.arange(N)[None, :]) & (
83+
jnp.arange(N)[None, :] <= idx
84+
)
85+
mask = mask.astype(M.dtype) # (N, N)
86+
87+
# Extend 6-dimensional mask (N, N, 1) to apply to (N, m)
88+
masked_prod = mask[:, :, None] * prod[None, :, :] # (N, N, m)
89+
90+
# Sum over j for each i : (N, m)
91+
result = masked_prod.sum(axis=1) # (N, m)
92+
return result
93+
6094
if __name__ == "__main__":
6195
# Example usage
6296
a = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

src/jsrm/systems/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
def substitute_params_into_all_symbolic_expressions(
1818
sym_exps: Dict, params: Dict[str, Array]
19-
sym_exps: Dict, params: Dict[str, Array]
2019
) -> Dict:
2120
"""
2221
Substitute robot parameters into symbolic expressions.
@@ -57,7 +56,6 @@ def substitute_params_into_single_symbolic_expression(
5756
sym_exp: sp.Expr,
5857
params_syms: Dict[str, List[sp.Symbol]],
5958
params: Dict[str, Array],
60-
params: Dict[str, Array],
6159
) -> sp.Expr:
6260
"""
6361
Substitute robot parameters into a single symbolic expression.
@@ -98,7 +96,6 @@ def concatenate_params_syms(
9896

9997
def compute_strain_basis(
10098
strain_selector: Array,
101-
) -> Array:
10299
) -> Array:
103100
"""
104101
Compute strain basis based on boolean strain selector.

0 commit comments

Comments
 (0)