Skip to content

Commit df770ce

Browse files
committed
Merge branch 'main' into numerical-derivation
2 parents fc144e5 + b7944dc commit df770ce

File tree

6 files changed

+2411
-1326
lines changed

6 files changed

+2411
-1326
lines changed

examples/simulate_planar_pcs_sym.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def draw_robot(
131131
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
132132
# jit the ODE function
133133
ode_fn = jax.jit(ode_fn)
134+
# jit the ODE function
135+
ode_fn = jax.jit(ode_fn)
134136
term = ODETerm(ode_fn)
135137

136138
sol = diffeqsolve(
@@ -150,6 +152,13 @@ def draw_robot(
150152
# the evolution of the generalized velocities
151153
q_d_ts = sol.ys[:, n_q:]
152154

155+
s_max = jnp.array([jnp.sum(params["l"])])
156+
157+
forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, params, s=s_max)
158+
forward_kinematics_fn_end_effector = jax.jit(forward_kinematics_fn_end_effector)
159+
forward_kinematics_fn_end_effector = vmap(forward_kinematics_fn_end_effector)
160+
161+
153162
s_max = jnp.array([jnp.sum(params["l"])])
154163

155164
forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, params, s=s_max)
@@ -158,6 +167,7 @@ def draw_robot(
158167

159168
# evaluate the forward kinematics along the trajectory
160169
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
170+
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
161171
# plot the configuration vs time
162172
plt.figure()
163173
for segment_idx in range(num_segments):
@@ -215,9 +225,11 @@ def draw_robot(
215225
# plot the energy along the trajectory
216226
kinetic_energy_fn_vmapped = vmap(
217227
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
228+
partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params)
218229
)
219230
potential_energy_fn_vmapped = vmap(
220231
partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params)
232+
partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params)
221233
)
222234
U_ts = potential_energy_fn_vmapped(q_ts)
223235
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)

src/jsrm/math_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
from jax import numpy as jnp
22
from jax import Array, lax
33

4+
def blk_diag(
5+
a: Array
6+
) -> Array:
7+
from jax import Array, lax
8+
49
def blk_diag(
510
a: Array
611
) -> Array:
712
"""
813
Create a block diagonal matrix from a tensor of blocks.
914
15+
Create a block diagonal matrix from a tensor of blocks.
16+
1017
Args:
1118
a: matrices to be block diagonalized of shape (m, n, o)
1219
@@ -33,6 +40,7 @@ def assign_block_diagonal(i, _b):
3340
# Implement for loop to assign each block in `a` to the block-diagonal of `b`
3441
# Hint: use `jax.lax.fori_loop` and pass `assign_block_diagonal` as an argument
3542
b = jnp.zeros((a.shape[0] * a.shape[1], a.shape[0] * a.shape[2]), dtype=a.dtype)
43+
b = jnp.zeros((a.shape[0] * a.shape[1], a.shape[0] * a.shape[2]), dtype=a.dtype)
3644
b = lax.fori_loop(
3745
lower=0,
3846
upper=a.shape[0],

0 commit comments

Comments
 (0)