diff --git a/examples/simulate_pcs.py b/examples/simulate_pcs.py new file mode 100644 index 0000000..888a9b7 --- /dev/null +++ b/examples/simulate_pcs.py @@ -0,0 +1,280 @@ +import jax + +from jsrm.systems.pcs import PCS +import jax.numpy as jnp + +from typing import Callable +from jax import Array + +import numpy as onp + +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation +from IPython.display import HTML + +from diffrax import Tsit5 + +from functools import partial +from matplotlib.widgets import Slider + +jax.config.update("jax_enable_x64", True) # double precision +jnp.set_printoptions( + threshold=jnp.inf, + linewidth=jnp.inf, + formatter={"float_kind": lambda x: "0" if x == 0 else f"{x:.2e}"}, +) + + +def draw_robot_curve( + batched_forward_kinematics: Callable, + L_max: float, + q: Array, + num_points: int = 50, +): + s_ps = jnp.linspace(0, L_max, num_points) + g_ps = batched_forward_kinematics(q, s_ps)[:, :3, 3] + + curve = onp.array(g_ps, dtype=onp.float64) + return curve # (N, 3) + + +def animate_robot_matplotlib( + robot: PCS, + t_list: Array, # shape (T,) + q_list: Array, # shape (T, DOF) + num_points: int = 50, + interval: int = 50, + slider: bool = None, + animation: bool = None, + show: bool = True, +): + if slider is None and animation is None: + raise ValueError("Either 'slider' or 'animation' must be set to True.") + if animation and slider: + raise ValueError( + "Cannot use both animation and slider at the same time. Choose one." + ) + + batched_forward_kinematics = jax.vmap(robot.forward_kinematics, in_axes=(None, 0)) + L_max = jnp.sum(robot.L) + + width = jnp.linalg.norm(robot.L) * 3 + height = width + + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03]) # [left, bottom, width, height] + + if animation: + (line,) = ax.plot([], [], [], lw=4, color="blue") + ax.set_xlim(-width / 2, width / 2) + ax.set_ylim(-width / 2, width / 2) + ax.set_zlim(0, height) + title_text = ax.set_title("t = 0.00 s") + + def init(): + line.set_data([], []) + line.set_3d_properties([]) + title_text.set_text("t = 0.00 s") + return line, title_text + + def update(frame_idx): + q = q_list[frame_idx] + t = t_list[frame_idx] + curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points) + line.set_data(curve[:, 0], curve[:, 1]) + line.set_3d_properties(curve[:, 2]) + title_text.set_text(f"t = {t:.2f} s") + return line, title_text + + ani = FuncAnimation( + fig, + update, + frames=len(q_list), + init_func=init, + blit=False, + interval=interval, + ) + + if show: + plt.show() + + plt.close(fig) + return HTML(ani.to_jshtml()) + + elif slider: + + def update_plot(frame_idx): + ax.cla() # Clear current axes + ax.set_xlim(-width / 2, width / 2) + ax.set_ylim(-width / 2, width / 2) + ax.set_zlim(0, height) + ax.set_xlabel("X [m]") + ax.set_ylabel("Y [m]") + ax.set_zlabel("Z [m]") + ax.set_title(f"t = {t_list[frame_idx]:.2f} s") + q = q_list[frame_idx] + curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points) + ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], lw=4, color="blue") + fig.canvas.draw_idle() + + # Create slider + slider = Slider( + ax=ax_slider, + label="Frame", + valmin=0, + valmax=len(t_list) - 1, + valinit=0, + valstep=1, + ) + slider.on_changed(update_plot) + + update_plot(0) # Initial plot + + if show: + plt.show() + + plt.close(fig) + return HTML( + "Slider animation not implemented in HTML format. Use matplotlib directly to view the slider." + ) # Slider cannot be converted to HTML + + +if __name__ == "__main__": + num_segments = 2 + rho = 1070 * jnp.ones( + (num_segments,) + ) # Volumetric density of Dragon Skin 20 [kg/m^3] + params = { + "p0": jnp.array( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + ), # Initial position and orientation + "l": 1e-1 * jnp.ones((num_segments,)), + "r": 2e-2 * jnp.ones((num_segments,)), + "rho": rho, + "g": jnp.array([0.0, 0.0, -9.81]), # Gravity vector [m/s^2] + "E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa] + "G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa] + } + params["D"] = 1e-3 * jnp.diag( + ( + jnp.repeat( + jnp.array([[1e0, 1e0, 1e0, 1e3, 1e3, 1e3]]), num_segments, axis=0 + ) + * params["l"][:, None] + ).flatten() + ) + + # ====================================================== + # Robot initialization + # ====================================================== + robot = PCS( + num_segments=num_segments, + params=params, + order_gauss=5, + ) + + # ===================================================== + # Simulation upon time + # ===================================================== + # Initial configuration + q0 = jnp.repeat( + jnp.array([5.0 * jnp.pi, 0.0, 0.0, 0.0, 0.1, 0.2])[None, :], + num_segments, + axis=0, + ).flatten() + # Initial velocities + qd0 = jnp.zeros_like(q0) + + # Actuation parameters + tau = jnp.zeros_like(q0) + # WARNING: actuation_args need to be a tuple, even if it contains only one element + # so (tau, ) is necessary NOT (tau) or tau + actuation_args = (tau,) + + # Simulation time parameters + t0 = 0.0 + t1 = 2.0 + dt = 1e-4 + skip_step = 100 # how many time steps to skip in between video frames + + # Solver + solver = Tsit5() # Runge-Kutta 5(4) method + + ts, q_ts, q_d_ts = robot.resolve_upon_time( + q0=q0, + qd0=qd0, + actuation_args=actuation_args, + t0=t0, + t1=t1, + dt=dt, + skip_steps=skip_step, + solver=solver, + max_steps=None, + ) + + # ===================================================== + # End-effector position upon time + # ===================================================== + forward_kinematics_end_effector = jax.jit( + partial( + robot.forward_kinematics, + s=jnp.sum(robot.L), # end-effector position + ) + ) + g_ee_ts = jax.vmap(forward_kinematics_end_effector)(q_ts) + + plt.figure() + plt.plot(ts, g_ee_ts[:, 0, 3], label="End-effector x [m]") + plt.plot(ts, g_ee_ts[:, 1, 3], label="End-effector y [m]") + plt.plot(ts, g_ee_ts[:, 2, 3], label="End-effector z [m]") + plt.xlabel("Time [s]") + plt.ylabel("End-effector position [m]") + plt.legend() + plt.grid(True) + plt.box(True) + plt.tight_layout() + plt.show() + + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + p = ax.scatter( + g_ee_ts[:, 0, 3], g_ee_ts[:, 1, 3], g_ee_ts[:, 2, 3], c=ts, cmap="viridis" + ) + ax.axis("equal") + ax.set_xlabel("X [m]") + ax.set_ylabel("Y [m]") + ax.set_zlabel("Z [m]") + ax.set_title("End-effector trajectory (3D)") + fig.colorbar(p, ax=ax, label="Time [s]") + plt.show() + + # ===================================================== + # Energy computation upon time + # ===================================================== + U_ts = jax.vmap(jax.jit(partial(robot.potential_energy)))(q_ts) + T_ts = jax.vmap(jax.jit(partial(robot.kinetic_energy)))(q_ts, q_d_ts) + + plt.figure() + plt.plot(ts, U_ts, label="Potential Energy") + plt.plot(ts, T_ts, label="Kinetic Energy") + plt.xlabel("Time (s)") + plt.ylabel("Energy (J)") + plt.legend() + plt.title("Energy over Time") + plt.grid(True) + plt.box(True) + plt.tight_layout() + plt.show() + + # ===================================================== + # Plot the robot configuration upon time + # ===================================================== + animate_robot_matplotlib( + robot, + t_list=ts, # shape (T,) + q_list=q_ts, # shape (T, DOF) + num_points=50, + interval=100, # ms + slider=True, + ) diff --git a/examples/simulate_planar_pcs.py b/examples/simulate_planar_pcs.py index e71eeac..1939f67 100644 --- a/examples/simulate_planar_pcs.py +++ b/examples/simulate_planar_pcs.py @@ -1,195 +1,232 @@ -import cv2 # importing cv2 -from functools import partial import jax -jax.config.update("jax_enable_x64", True) # double precision -from diffrax import diffeqsolve, Euler, ODETerm, SaveAt, Tsit5 -from jax import Array, vmap -from jax import numpy as jnp -import matplotlib.pyplot as plt -import numpy as onp -from pathlib import Path +from jsrm.systems.planar_pcs import PlanarPCS +import jax.numpy as jnp + from typing import Callable, Dict +from jax import Array + +import numpy as onp -import jsrm -from jsrm import ode_factory -from jsrm.systems import planar_pcs +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation +from IPython.display import HTML -num_segments = 1 +from diffrax import Tsit5 -# filepath to symbolic expressions -sym_exp_filepath = ( - Path(jsrm.__file__).parent - / "symbolic_expressions" - / f"planar_pcs_ns-{num_segments}.dill" -) +from functools import partial +from matplotlib.widgets import Slider -# set parameters -rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3] -params = { - "th0": jnp.array(0.0), # initial orientation angle [rad] - "l": 1e-1 * jnp.ones((num_segments,)), - "r": 2e-2 * jnp.ones((num_segments,)), - "rho": rho, - "g": jnp.array([0.0, 9.81]), - "E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa] - "G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa] -} -params["D"] = 1e-3 * jnp.diag( - (jnp.repeat( - jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0 - ) * params["l"][:, None]).flatten() +jax.config.update("jax_enable_x64", True) # double precision +jnp.set_printoptions( + threshold=jnp.inf, + linewidth=jnp.inf, + formatter={"float_kind": lambda x: "0" if x == 0 else f"{x:.2e}"}, ) -# activate all strains (i.e. bending, shear, and axial) -strain_selector = jnp.ones((3 * num_segments,), dtype=bool) -# define initial configuration -q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0).flatten() -# number of generalized coordinates -n_q = q0.shape[0] +def draw_robot_curve( + batched_forward_kinematics: Callable, + L_max: float, + q: Array, + num_points: int = 50, +): + s_ps = jnp.linspace(0, L_max, num_points) + chi_ps = batched_forward_kinematics(q, s_ps) -# set simulation parameters -dt = 1e-4 # time step -ts = jnp.arange(0.0, 2, dt) # time steps -skip_step = 10 # how many time steps to skip in between video frames -video_ts = ts[::skip_step] # time steps for video + curve = onp.array(chi_ps[1:, :], dtype=onp.float32).T -# video settings -video_width, video_height = 700, 700 # img height and width -video_path = Path(__file__).parent / "videos" / f"planar_pcs_ns-{num_segments}.mp4" + return curve # (N, 2) -def draw_robot( - batched_forward_kinematics_fn: Callable, - params: Dict[str, Array], - q: Array, - width: int, - height: int, +def animate_robot_matplotlib( + robot: PlanarPCS, + t_list: Array, # shape (T,) + q_list: Array, # shape (T, DOF) num_points: int = 50, -) -> onp.ndarray: - # plotting in OpenCV - h, w = height, width # img height and width - ppm = h / (2.0 * jnp.sum(params["l"])) # pixel per meter - base_color = (0, 0, 0) # black robot_color in BGR - robot_color = (255, 0, 0) # black robot_color in BGR - - # we use for plotting N points along the length of the robot - s_ps = jnp.linspace(0, jnp.sum(params["l"]), num_points) - - # poses along the robot of shape (3, N) - chi_ps = batched_forward_kinematics_fn(params, q, s_ps) - - img = 255 * onp.ones((w, h, 3), dtype=jnp.uint8) # initialize background to white - curve_origin = onp.array( - [w // 2, 0.1 * h], dtype=onp.int32 - ) # in x-y pixel coordinates - # draw base - cv2.rectangle(img, (0, h - curve_origin[1]), (w, h), color=base_color, thickness=-1) - # transform robot poses to pixel coordinates - # should be of shape (N, 2) - curve = onp.array((curve_origin + chi_ps[:2, :].T * ppm), dtype=onp.int32) - # invert the v pixel coordinate - curve[:, 1] = h - curve[:, 1] - cv2.polylines(img, [curve], isClosed=False, color=robot_color, thickness=10) - - return img + interval: int = 50, + slider: bool = None, + animation: bool = None, + show: bool = True, +): + if slider is None and animation is None: + raise ValueError("Either 'slider' or 'animation' must be set to True.") + if animation and slider: + raise ValueError( + "Cannot use both animation and slider at the same time. Choose one." + ) + + batched_forward_kinematics = jax.vmap( + robot.forward_kinematics, in_axes=(None, 0), out_axes=-1 + ) + L_max = jnp.sum(robot.L) + + width = jnp.linalg.norm(robot.L) * 3 + height = width + + fig = plt.figure() + ax = fig.add_subplot(111) + ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03]) # [left, bottom, width, height] + + if animation: + (line,) = ax.plot([], [], lw=4, color="blue") + ax.set_xlim(-width / 2, width / 2) + ax.set_ylim(0, height) + title_text = ax.set_title("t = 0.00 s") + + def init(): + line.set_data([], []) + title_text.set_text("t = 0.00 s") + return line, title_text + + def update(frame_idx): + q = q_list[frame_idx] + t = t_list[frame_idx] + curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points) + line.set_data(curve[:, 0], curve[:, 1]) + title_text.set_text(f"t = {t:.2f} s") + return line, title_text + + ani = FuncAnimation( + fig, + update, + frames=len(q_list), + init_func=init, + blit=False, + interval=interval, + ) + + if show: + plt.show() + plt.close(fig) + return HTML(ani.to_jshtml()) + + elif slider: + + def update_plot(frame_idx): + ax.cla() # Clear current axes + ax.set_xlim(-width / 2, width / 2) + ax.set_ylim(0, height) + ax.set_xlabel("X [m]") + ax.set_ylabel("Y [m]") + ax.set_title(f"t = {t_list[frame_idx]:.2f} s") + q = q_list[frame_idx] + curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points) + ax.plot(curve[:, 0], curve[:, 1], lw=4, color="blue") + fig.canvas.draw_idle() + + # Create slider + slider = Slider( + ax=ax_slider, + label="Frame", + valmin=0, + valmax=len(t_list) - 1, + valinit=0, + valstep=1, + ) + slider.on_changed(update_plot) + + update_plot(0) # Initial plot + + if show: + plt.show() + + plt.close(fig) + return HTML( + "Slider animation not implemented in HTML format. Use matplotlib directly to view the slider." + ) # Slider cannot be converted to HTML if __name__ == "__main__": - strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = ( - planar_pcs.factory(sym_exp_filepath, strain_selector) + num_segments = 2 + rho = 1070 * jnp.ones( + (num_segments,) + ) # Volumetric density of Dragon Skin 20 [kg/m^3] + params = { + "th0": jnp.array(0.0), # initial orientation angle [rad] + "l": 1e-1 * jnp.ones((num_segments,)), + "r": 2e-2 * jnp.ones((num_segments,)), + "rho": rho, + "g": jnp.array([0.0, -9.81]), + "E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa] + "G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa] + } + params["D"] = 1e-3 * jnp.diag( + ( + jnp.repeat(jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0) + * params["l"][:, None] + ).flatten() ) - # jit the functions - dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn)) - batched_forward_kinematics = vmap( - forward_kinematics_fn, in_axes=(None, None, 0), out_axes=-1 + + # ====================================================== + # Robot initialization + # ====================================================== + robot = PlanarPCS( + num_segments=num_segments, + params=params, + order_gauss=5, ) - # import matplotlib.pyplot as plt - # plt.plot(chi_ps[0, :], chi_ps[1, :]) - # plt.axis("equal") - # plt.grid(True) - # plt.xlabel("x [m]") - # plt.ylabel("y [m]") - # plt.show() - - # Displaying the image - # window_name = f"Planar PCS with {num_segments} segments" - # img = draw_robot(batched_forward_kinematics, params, q0, video_width, video_height) - # cv2.namedWindow(window_name) - # cv2.imshow(window_name, img) - # cv2.waitKey() - # cv2.destroyWindow(window_name) - - x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition - tau = jnp.zeros_like(q0) # torques - - ode_fn = ode_factory(dynamical_matrices_fn, params, tau) - # jit the ODE function - ode_fn = jax.jit(ode_fn) - term = ODETerm(ode_fn) - - sol = diffeqsolve( - term, - solver=Tsit5(), - t0=ts[0], - t1=ts[-1], - dt0=dt, - y0=x0, + # ===================================================== + # Simulation upon time + # ===================================================== + # Initial configuration + q0 = jnp.repeat( + jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0 + ).flatten() + # Initial velocities + qd0 = jnp.zeros_like(q0) + + # Actuation parameters + tau = jnp.zeros_like(q0) + # WARNING: actuation_args need to be a tuple, even if it contains only one element + actuation_args = (tau,) + + # Simulation time parameters + t0 = 0.0 + t1 = 2.0 + dt = 1e-4 + skip_step = 100 # how many time steps to skip in between video frames + + # Solver + solver = Tsit5() # Runge-Kutta 5(4) method + + ts, q_ts, q_d_ts = robot.resolve_upon_time( + q0=q0, + qd0=qd0, + actuation_args=actuation_args, + t0=t0, + t1=t1, + dt=dt, + skip_steps=skip_step, + solver=solver, max_steps=None, - saveat=SaveAt(ts=video_ts), ) - print("sol.ys =\n", sol.ys) - # the evolution of the generalized coordinates - q_ts = sol.ys[:, :n_q] - # the evolution of the generalized velocities - q_d_ts = sol.ys[:, n_q:] - - s_max = jnp.array([jnp.sum(params["l"])]) - - forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, params, s=s_max) - forward_kinematics_fn_end_effector = jax.jit(forward_kinematics_fn_end_effector) - forward_kinematics_fn_end_effector = vmap(forward_kinematics_fn_end_effector) - - # evaluate the forward kinematics along the trajectory - chi_ee_ts = forward_kinematics_fn_end_effector(q_ts) - # plot the configuration vs time - plt.figure() - for segment_idx in range(num_segments): - plt.plot( - video_ts, q_ts[:, 3 * segment_idx + 0], - label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]" - ) - plt.plot( - video_ts, q_ts[:, 3 * segment_idx + 1], - label=r"$\sigma_\mathrm{sh," + str(segment_idx + 1) + "}$ [-]" + # ===================================================== + # End-effector position upon time + # ===================================================== + forward_kinematics_end_effector = jax.jit( + partial( + robot.forward_kinematics, + s=jnp.sum(robot.L), # end-effector position ) - plt.plot( - video_ts, q_ts[:, 3 * segment_idx + 2], - label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]" - ) - plt.xlabel("Time [s]") - plt.ylabel("Configuration") - plt.legend() - plt.grid(True) - plt.tight_layout() - plt.show() - # plot end-effector position vs time + ) + chi_ee_ts = jax.vmap(forward_kinematics_end_effector)(q_ts) + plt.figure() - plt.plot(video_ts, chi_ee_ts[:, 0], label="x") - plt.plot(video_ts, chi_ee_ts[:, 1], label="y") + plt.plot(ts, chi_ee_ts[:, 1], label="End-effector x [m]") + plt.plot(ts, chi_ee_ts[:, 2], label="End-effector y [m]") plt.xlabel("Time [s]") - plt.ylabel("End-effector Position [m]") + plt.ylabel("End-effector position [m]") plt.legend() plt.grid(True) plt.box(True) plt.tight_layout() plt.show() - # plot the end-effector position in the x-y plane as a scatter plot with the time as the color + plt.figure() - plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=video_ts, cmap="viridis") + plt.scatter(chi_ee_ts[:, 1], chi_ee_ts[:, 2], c=ts, cmap="viridis") plt.axis("equal") plt.grid(True) plt.xlabel("End-effector x [m]") @@ -197,55 +234,33 @@ def draw_robot( plt.colorbar(label="Time [s]") plt.tight_layout() plt.show() - # plt.figure() - # plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1]) - # plt.axis("equal") - # plt.grid(True) - # plt.xlabel("End-effector x [m]") - # plt.ylabel("End-effector y [m]") - # plt.tight_layout() - # plt.show() - - # plot the energy along the trajectory - kinetic_energy_fn_vmapped = vmap( - partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params) - ) - potential_energy_fn_vmapped = vmap( - partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params) - ) - U_ts = potential_energy_fn_vmapped(q_ts) - T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts) + + # ===================================================== + # Energy computation upon time + # ===================================================== + U_ts = jax.vmap(jax.jit(partial(robot.potential_energy)))(q_ts) + T_ts = jax.vmap(jax.jit(partial(robot.kinetic_energy)))(q_ts, q_d_ts) + plt.figure() - plt.plot(video_ts, U_ts, label="Potential energy") - plt.plot(video_ts, T_ts, label="Kinetic energy") - plt.xlabel("Time [s]") - plt.ylabel("Energy [J]") + plt.plot(ts, U_ts, label="Potential Energy") + plt.plot(ts, T_ts, label="Kinetic Energy") + plt.xlabel("Time (s)") + plt.ylabel("Energy (J)") plt.legend() + plt.title("Energy over Time") plt.grid(True) plt.box(True) plt.tight_layout() plt.show() - # create video - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - video_path.parent.mkdir(parents=True, exist_ok=True) - video = cv2.VideoWriter( - str(video_path), - fourcc, - 1 / (skip_step * dt), # fps - (video_width, video_height), + # ===================================================== + # Plot the robot configuration upon time + # ===================================================== + animate_robot_matplotlib( + robot=robot, + t_list=ts, # shape (T,) + q_list=q_ts, # shape (T, DOF) + num_points=50, + interval=100, # ms + slider=True, ) - - for time_idx, t in enumerate(video_ts): - x = sol.ys[time_idx] - img = draw_robot( - batched_forward_kinematics, - params, - x[: (x.shape[0] // 2)], - video_width, - video_height, - ) - video.write(img) - - video.release() - print(f"Video saved at {video_path}") diff --git a/examples/simulate_planar_pcs_sym.py b/examples/simulate_planar_pcs_sym.py new file mode 100644 index 0000000..0b84f83 --- /dev/null +++ b/examples/simulate_planar_pcs_sym.py @@ -0,0 +1,259 @@ +import cv2 # importing cv2 +from functools import partial +import jax + +jax.config.update("jax_enable_x64", True) # double precision +from diffrax import diffeqsolve, Euler, ODETerm, SaveAt, Tsit5 +from jax import Array, vmap +from jax import numpy as jnp +import matplotlib.pyplot as plt +import numpy as onp +from pathlib import Path +from typing import Callable, Dict + +import jsrm +from jsrm import ode_factory +from jsrm.systems import planar_pcs_sym + +num_segments = 1 + +# filepath to symbolic expressions +sym_exp_filepath = ( + Path(jsrm.__file__).parent + / "symbolic_expressions" + / f"planar_pcs_ns-{num_segments}.dill" +) + +# set parameters +rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3] +params = { + "th0": jnp.array(0.0), # initial orientation angle [rad] + "l": 1e-1 * jnp.ones((num_segments,)), + "r": 2e-2 * jnp.ones((num_segments,)), + "rho": rho, + "g": jnp.array([0.0, 9.81]), + "E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa] + "G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa] +} +params["D"] = 1e-3 * jnp.diag( + ( + jnp.repeat(jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0) + * params["l"][:, None] + ).flatten() +) + +# activate all strains (i.e. bending, shear, and axial) +strain_selector = jnp.ones((3 * num_segments,), dtype=bool) + +# define initial configuration +q0 = jnp.repeat( + jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0 +).flatten() +# number of generalized coordinates +n_q = q0.shape[0] + +# set simulation parameters +dt = 1e-4 # time step +ts = jnp.arange(0.0, 2, dt) # time steps +skip_step = 10 # how many time steps to skip in between video frames +video_ts = ts[::skip_step] # time steps for video + +# video settings +video_width, video_height = 700, 700 # img height and width +video_path = Path(__file__).parent / "videos" / f"planar_pcs_ns-{num_segments}.mp4" + + +def draw_robot( + batched_forward_kinematics_fn: Callable, + params: Dict[str, Array], + q: Array, + width: int, + height: int, + num_points: int = 50, +) -> onp.ndarray: + # plotting in OpenCV + h, w = height, width # img height and width + ppm = h / (2.0 * jnp.sum(params["l"])) # pixel per meter + base_color = (0, 0, 0) # black robot_color in BGR + robot_color = (255, 0, 0) # black robot_color in BGR + + # we use for plotting N points along the length of the robot + s_ps = jnp.linspace(0, jnp.sum(params["l"]), num_points) + + # poses along the robot of shape (3, N) + chi_ps = batched_forward_kinematics_fn(params, q, s_ps) + + img = 255 * onp.ones((w, h, 3), dtype=jnp.uint8) # initialize background to white + curve_origin = onp.array( + [w // 2, 0.1 * h], dtype=onp.int32 + ) # in x-y pixel coordinates + # draw base + cv2.rectangle(img, (0, h - curve_origin[1]), (w, h), color=base_color, thickness=-1) + # transform robot poses to pixel coordinates + # should be of shape (N, 2) + curve = onp.array((curve_origin + chi_ps[:2, :].T * ppm), dtype=onp.int32) + # invert the v pixel coordinate + curve[:, 1] = h - curve[:, 1] + cv2.polylines(img, [curve], isClosed=False, color=robot_color, thickness=10) + + return img + + +if __name__ == "__main__": + strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = ( + planar_pcs_sym.factory(sym_exp_filepath, strain_selector) + ) + # jit the functions + dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn)) + batched_forward_kinematics = vmap( + forward_kinematics_fn, in_axes=(None, None, 0), out_axes=-1 + ) + + # import matplotlib.pyplot as plt + # plt.plot(chi_ps[0, :], chi_ps[1, :]) + # plt.axis("equal") + # plt.grid(True) + # plt.xlabel("x [m]") + # plt.ylabel("y [m]") + # plt.show() + + # Displaying the image + # window_name = f"Planar PCS with {num_segments} segments" + # img = draw_robot(batched_forward_kinematics, params, q0, video_width, video_height) + # cv2.namedWindow(window_name) + # cv2.imshow(window_name, img) + # cv2.waitKey() + # cv2.destroyWindow(window_name) + + x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition + tau = jnp.zeros_like(q0) # torques + + ode_fn = ode_factory(dynamical_matrices_fn, params, tau) + # jit the ODE function + ode_fn = jax.jit(ode_fn) + # jit the ODE function + ode_fn = jax.jit(ode_fn) + term = ODETerm(ode_fn) + + sol = diffeqsolve( + term, + solver=Tsit5(), + t0=ts[0], + t1=ts[-1], + dt0=dt, + y0=x0, + max_steps=None, + saveat=SaveAt(ts=video_ts), + ) + + print("sol.ys =\n", sol.ys) + # the evolution of the generalized coordinates + q_ts = sol.ys[:, :n_q] + # the evolution of the generalized velocities + q_d_ts = sol.ys[:, n_q:] + + s_max = jnp.array([jnp.sum(params["l"])]) + + forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, params, s=s_max) + forward_kinematics_fn_end_effector = jax.jit(forward_kinematics_fn_end_effector) + forward_kinematics_fn_end_effector = vmap(forward_kinematics_fn_end_effector) + + # evaluate the forward kinematics along the trajectory + chi_ee_ts = forward_kinematics_fn_end_effector(q_ts) + # plot the configuration vs time + plt.figure() + for segment_idx in range(num_segments): + plt.plot( + video_ts, + q_ts[:, 3 * segment_idx + 0], + label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]", + ) + plt.plot( + video_ts, + q_ts[:, 3 * segment_idx + 1], + label=r"$\sigma_\mathrm{sh," + str(segment_idx + 1) + "}$ [-]", + ) + plt.plot( + video_ts, + q_ts[:, 3 * segment_idx + 2], + label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]", + ) + plt.xlabel("Time [s]") + plt.ylabel("Configuration") + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.show() + # plot end-effector position vs time + plt.figure() + plt.plot(video_ts, chi_ee_ts[:, 0], label="x") + plt.plot(video_ts, chi_ee_ts[:, 1], label="y") + plt.xlabel("Time [s]") + plt.ylabel("End-effector Position [m]") + plt.legend() + plt.grid(True) + plt.box(True) + plt.tight_layout() + plt.show() + # plot the end-effector position in the x-y plane as a scatter plot with the time as the color + plt.figure() + plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=video_ts, cmap="viridis") + plt.axis("equal") + plt.grid(True) + plt.xlabel("End-effector x [m]") + plt.ylabel("End-effector y [m]") + plt.colorbar(label="Time [s]") + plt.tight_layout() + plt.show() + # plt.figure() + # plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1]) + # plt.axis("equal") + # plt.grid(True) + # plt.xlabel("End-effector x [m]") + # plt.ylabel("End-effector y [m]") + # plt.tight_layout() + # plt.show() + + # plot the energy along the trajectory + kinetic_energy_fn_vmapped = vmap( + partial(jax.jit(auxiliary_fns["kinetic_energy_fn"]), params) + ) + potential_energy_fn_vmapped = vmap( + partial(jax.jit(auxiliary_fns["potential_energy_fn"]), params) + ) + U_ts = potential_energy_fn_vmapped(q_ts) + T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts) + plt.figure() + plt.plot(video_ts, U_ts, label="Potential energy") + plt.plot(video_ts, T_ts, label="Kinetic energy") + plt.xlabel("Time [s]") + plt.ylabel("Energy [J]") + plt.legend() + plt.grid(True) + plt.box(True) + plt.tight_layout() + plt.show() + + # create video + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + video_path.parent.mkdir(parents=True, exist_ok=True) + video = cv2.VideoWriter( + str(video_path), + fourcc, + 1 / (skip_step * dt), # fps + (video_width, video_height), + ) + + for time_idx, t in enumerate(video_ts): + x = sol.ys[time_idx] + img = draw_robot( + batched_forward_kinematics, + params, + x[: (x.shape[0] // 2)], + video_width, + video_height, + ) + video.write(img) + + video.release() + print(f"Video saved at {video_path}") diff --git a/pyproject.toml b/pyproject.toml index c13dec0..c8b08b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ name = "jsrm" # Required # # For a discussion on single-sourcing the version, see # https://packaging.python.org/guides/single-sourcing-package-version/ -version = "0.0.17" # Required +version = "0.1.0" # Required # This is a one-line description or tagline of what your project does. This # corresponds to the "Summary" metadata field: @@ -107,6 +107,7 @@ dependencies = [ # Optional "jax", "numpy", "quadax", + "equinox", "peppercorn", "sympy>=1.11" ] diff --git a/src/jsrm/math_utils.py b/src/jsrm/math_utils.py index 8088921..e5323c6 100644 --- a/src/jsrm/math_utils.py +++ b/src/jsrm/math_utils.py @@ -1,9 +1,8 @@ from jax import numpy as jnp from jax import Array, lax -def blk_diag( - a: Array -) -> Array: + +def blk_diag(a: Array) -> Array: """ Create a block diagonal matrix from a tensor of blocks. @@ -42,9 +41,8 @@ def assign_block_diagonal(i, _b): return b -def blk_concat( - a: Array -) -> Array: + +def blk_concat(a: Array) -> Array: """ Concatenate horizontally (along the columns) a list of N matrices of size (m, n) to create a single matrix of size (m, n * N). @@ -57,16 +55,36 @@ def blk_concat( b = a.transpose(1, 0, 2).reshape(a.shape[1], -1) return b -if __name__ == "__main__": - # Example usage - a = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) - print("Original array:") - print(a) - - b = blk_diag(a) - print("Block diagonal matrix:") - print(b) - - c = blk_concat(a) - print("Concatenated matrix:") - print(c) \ No newline at end of file + +def compute_weighted_sums(M: Array, vecm: Array, idx: int) -> Array: + """ + Compute the weighted sums of the matrix product of M and vecm, + + Args: + M (Array): array of shape (N, m, m) + Describes the matrix to be multiplied with vecm + vecm (Array): array-like of shape (N, m) + Describes the vector to be multiplied with M + idx (int): index of the last row to be summed over + + Returns: + Array: array of shape (N, m) + The result of the weighted sums. For each i, the result is the sum of the products of M[i, j] and vecm[j] for j from 0 to idx. + """ + N = M.shape[0] + # Matrix product for each j: (N, m, m) @ (N, m, 1) -> (N, m) + prod = jnp.einsum("nij,nj->ni", M, vecm) + + # Triangular mask for partial sum: (N, N) + # mask[i, j] = 1 if j >= i and j <= idx + mask = (jnp.arange(N)[:, None] <= jnp.arange(N)[None, :]) & ( + jnp.arange(N)[None, :] <= idx + ) + mask = mask.astype(M.dtype) # (N, N) + + # Extend 6-dimensional mask (N, N, 1) to apply to (N, m) + masked_prod = mask[:, :, None] * prod[None, :, :] # (N, N, m) + + # Sum over j for each i : (N, m) + result = masked_prod.sum(axis=1) # (N, m) + return result diff --git a/src/jsrm/systems/__init__.py b/src/jsrm/systems/__init__.py index ede258c..e69de29 100644 --- a/src/jsrm/systems/__init__.py +++ b/src/jsrm/systems/__init__.py @@ -1 +0,0 @@ -import jsrm.systems.planar_pcs_sym as planar_pcs diff --git a/src/jsrm/systems/pcs.py b/src/jsrm/systems/pcs.py new file mode 100644 index 0000000..8f96cdf --- /dev/null +++ b/src/jsrm/systems/pcs.py @@ -0,0 +1,1486 @@ +from jax import Array, lax, vmap +from jax import numpy as jnp + +from typing import Callable, Dict, Tuple, Optional + +import equinox as eqx + +from .utils import ( + compute_strain_basis, + compute_spatial_stiffness_matrix, + gauss_quadrature, + scale_gaussian_quadrature, +) +from jsrm.math_utils import ( + blk_diag, + compute_weighted_sums, +) +import jsrm.utils.lie_algebra as lie + +from diffrax import ( + diffeqsolve, + ODETerm, + SaveAt, + Tsit5, + PIDController, + ConstantStepSize, + AbstractSolver, +) + + +class PCS(eqx.Module): + """ + Piecewise Constant Strain (PCS) model for 3D soft continuum robots. + + This class implements the geometric and dynamic modeling of a 3D soft robot + using the Cosserat rod theory and piecewise constant strain assumption. + It supports computation of forward kinematics, Jacobians, dynamical matrices. + + Attributes: + ---------- + num_segments : int + Number of segments (constant strain sections) along the robot. + g0 : Array + Initial pose of the robot base as an SE(3) transformation matrix. + g : Array + Gravitational acceleration vector (embedded in a 6D vector). + [0, 0, 0, g_x, g_y, g_z] + L, r, E, G, rho, D : Array + Physical properties of each segment (length, radius, elastic/shear modulus, etc.). + num_active_strains : int + Number of active strain components (based on strain_selector). + num_strains : int + Total number of strain components (6 * num_segments). + B_xi : Array + Basis matrix for projecting active strains. + xi_star : Array + Rest strain (reference configuration) of the robot. + num_gauss_points : int + Number of points used for numerical integration. + Corresponds to the order of Gauss-Legendre quadrature + 2 (for the endpoints). + Xs, Ws : Array + Gauss-Legendre quadrature nodes and weights for numerical integration. + stiffness_fn : Callable + Function to compute the full stiffness matrix. + actuation_mapping_fn : Callable + Function to map actuation torques into strain space. + + Methods: + ------- + strain(q: Array) -> Array: + Computes the strain vector from generalized coordinates. + forward_kinematics(q: Array, s: Array) -> Array: + Computes the forward kinematics at a point s along the robot. + jacobian(q: Array, s: Array) -> Array: + Computes the Jacobian of the forward kinematics at a point s (global frame). + jacobian_and_derivative(q: Array, qd: Array, s: Array) -> Tuple[Array, Array]: + Computes the Jacobian and its time derivative at a point s (global frame). + dynamical_matrix(q: Array, qd: Array, actuation_args: Tuple[Array]) -> Array: + Computes the dynamical matrix for the system. + resolve_upon_time(q0: Array, qd0: Array, actuation_args: Tuple[Array], t0: float, t1: float, dt: float, skip_steps: int, solver: AbstractSolver, max_steps: Optional[int] = None) -> Tuple[Array, Array, Array]: + Simulates the robot dynamics over time using the specified solver. + forward_dynamics(t: float, y: Array, actuation_args: Optional[Tuple]) -> Array: + Computes the forward dynamics of the system at a given time t. + + kinetic_energy(q: Array, qd: Array) -> float: + Computes the kinetic energy of the system. + elastic_energy(q: Array) -> float: + Computes the elastic potential energy of the system. + gravitational_energy(q: Array) -> float: + Computes the gravitational potential energy of the system. + potential_energy(q: Array) -> float: + Computes the total potential energy (elastic + gravitational) of the system. + total_energy(q: Array, qd: Array) -> float: + Computes the total energy (kinetic + potential) of the system. + + operational_space_dynamical_matrices(q: Array, qd: Array, s: Array, operational_space_selector: Optional[Tuple]) -> Tuple[Array, Array, Array, Array, Array]: + # TODO + + inertia_matrix(q: Array) -> Array: + Computes the inertia matrix of the system. + coriolis_matrix(q: Array, qd: Array) -> Array: + Computes the Coriolis matrix of the system. + gravitational_vector(q: Array) -> Array: + Computes the gravitational force vector acting on the system. + stiffness_matrix() -> Array: + Computes the stiffness matrix of the system. + damping_matrix() -> Array: + Computes the damping matrix of the system. + actuation_mapping(q: Array, actuation_args: Tuple[Array]) -> Array: + Computes the actuation mapping of the system. + + classify_segment(s: Array) -> Tuple[Array, Array]: + Classifies a point along the robot to its corresponding segment and local coordinate. + jacobian_bodyframe(q: Array, s: Array) -> Array: + Computes the Jacobian of the forward kinematics at a point s (body frame). + jacobian_inertialframe(q: Array, s: Array) -> Array: + Computes the Jacobian of the forward kinematics at a point s (inertial frame). + jacobian_and_derivative_bodyframe(q: Array, qd: Array, s: Array) -> Tuple[Array, Array]: + Computes the Jacobian and its time derivative at a point s (body frame). + jacobian_and_derivative_inertialframe(q: Array, qd: Array, s: Array) -> Tuple[Array, Array]: + Computes the Jacobian and its time derivative at a point s (inertial frame). + + Notes: + ----- + - The strain vector is composed of 6 components per segment: + [kappa_x, kappa_y, kappa_z, sigma_x, sigma_y, sigma_z]. + By default, the rod is assumed to be straight and aligned with the z-axis, + so the rest strain is set to [0, 0, 0, 0, 0, 1]. + Thus: - kappa_x corresponds to bending around the x-axis, + - kappa_y corresponds to bending around the y-axis, + - kappa_z corresponds to torsion around the z-axis, + - sigma_x corresponds to shear along the x-axis, + - sigma_y corresponds to shear along the y-axis, + - sigma_z corresponds to axial strain along the z-axis. + + """ + + # Robot parameters + g0: Array # Initial position and orientation of the rod + g: Array # Gravitational acceleration vector + + L: Array # Length of the segments + L_cum: Array # Cumulative length of the segments + r: Array # Radius of the segments + rho: Array + E: Array # Young's modulus of the segments + G: Array # Shear modulus of the segments + D: Array # Damping coefficient of the segments + + global_eps: float = jnp.finfo(jnp.float64).eps + + stiffness_fn: Callable = eqx.static_field() + actuation_mapping_fn: Callable = eqx.static_field() + + num_segments: int = eqx.static_field() + num_gauss_points: int = eqx.static_field() # + num_strains: int = eqx.static_field() # Number of strains (6 * num_segments) + + xi_star: Array # Rest configuration strain + B_xi: Array # Strain basis matrix + + num_active_strains: Array # Number of selected strains + + Xs: Array # Gauss nodes + Ws: Array # Gauss weights + + def __init__( + self, + num_segments: int, + params: Dict[str, Array], + order_gauss: int = 5, + strain_selector: Optional[Array] = None, + xi_star: Optional[Array] = None, + stiffness_fn: Optional[Callable] = None, + actuation_mapping_fn: Optional[Callable] = None, + ) -> "PCS": + """ + Initialize the PCS class. + + Args: + num_segments (int): + Number of segments in the robot. + params (Dict[str, Array]): + Dictionary containing the robot parameters: + - "p0": Initial orientation angle and position in the inertial frame [rad, m] + [ψ, θ, φ, x0, y0, z0] + where [ψ, θ, φ] are the Euler angles in the ZXZ convention: + ψ (psi) : Rotation around Z axis (fixed axis) + θ (thêta) : Rotation around X' axis (movable axis after first rotation) + φ (phi) : Rotation about the Z' axis (movable axis after the first two rotations) + [x0, y0, z0] : Position of the robot in the inertial frame + - "l": Length of each segment [m] + - "r": Radius of each segment [m] + - "rho": Density of each segment [kg/m^3] + - "g": Gravitational acceleration vector [m/s^2] + - "E": Elastic modulus of each segment [Pa] + - "G": Shear modulus of each segment [Pa] + order_gauss (int, optional): + Order of the Gauss-Legendre quadrature for integration over each segment. + Defaults to 5. + strain_selector (Optional[Array], optional): + Boolean array of shape (6 * num_segments,) specifying which strain components are active. + Defaults to all strains active (i.e. all True). + xi_star (Optional[Array], optional): + Rest strain of shape (6 * num_segments,). + Defaults to 0.0 for bending and shear strains, and 1.0 for axial strain (along local z-axis). + stiffness_fn (Optional[Callable], optional): + Function to compute the stiffness matrix. + Defaults to : + l_i * diag( E_i * Ib_i, # bending X + E_i * Ib_i, # bending Y + G_i * J_i, # torsion Z + 4/3 * A_i * G_i, # shear X + 4/3 * A_i * G_i, # shear Y + A_i * E_i, # axial Z) + actuation_mapping_fn (Optional[Callable], optional): + Function to compute the actuation mapping. + This function needs to take as input: + - q: generalized coordinates of shape (num_active_strains,) + - actuation_args: tuple containing the actuation parameters (e.g. torques (tau,)). + Defaults to identity linear mapping. actuation_args = (tau,) + + """ + # Number of segments + if not isinstance(num_segments, int): + raise TypeError( + f"num_segments must be an integer, got {type(num_segments).__name__}" + ) + if num_segments < 1: + raise ValueError(f"num_segments must be at least 1, got {num_segments}") + self.num_segments = num_segments + + num_strains = 6 * num_segments + self.num_strains = num_strains + + # ================================================================ + # Robot parameters + + # Initial position and orientation angle + try: + p0 = params["p0"] + except KeyError: + raise KeyError("Parameter 'p0' is required in params dictionary.") + # if not (isinstance(p0, (float, int, jnp.ndarray))): + # raise TypeError( + # f"p0 must be a float, int, or an array, got {type(th0).__name__}" + # ) + p0 = jnp.asarray(p0, dtype=jnp.float64) + self.g0 = lie.exp_SE3(p0) + + # Gravitational acceleration vector + try: + g = params["g"] + except KeyError: + raise KeyError("Parameter 'g' is required in params dictionary.") + if not (isinstance(g, (list, jnp.ndarray))): + raise TypeError(f"g must be a list or an array, got {type(g).__name__}") + g = jnp.asarray(g, dtype=jnp.float64) + if g.size != 3: + raise ValueError(f"g must be a vector of shape (3,), got {g.size}") + self.g = jnp.concatenate( + [jnp.zeros(3), g] + ) # Add zeros for the orientation angles + + # Lengths of the segments + try: + L = params["l"] + except KeyError: + raise KeyError("Parameter 'l' is required in params dictionary.") + if not (isinstance(L, (list, jnp.ndarray))): + raise TypeError(f"l must be a list or an array, got {type(L).__name__}") + L = jnp.asarray(L, dtype=jnp.float64) + if L.shape != (num_segments,): + raise ValueError(f"l must have shape ({num_segments},), got {L.shape}") + self.L = L + + L_cum = jnp.cumsum(jnp.concatenate([jnp.zeros(1), self.L])) + self.L_cum = L_cum + + # Radius of the segments + try: + r = params["r"] + except KeyError: + raise KeyError("Parameter 'r' is required in params dictionary.") + if not (isinstance(r, (list, jnp.ndarray))): + raise TypeError(f"r must be a list or an array, got {type(r).__name__}") + r = jnp.asarray(r, dtype=jnp.float64) + if r.shape != (num_segments,): + raise ValueError(f"r must have shape ({num_segments},), got {r.shape}") + self.r = r + + # Densities of the segments + try: + rho = params["rho"] + except KeyError: + raise KeyError("Parameter 'rho' is required in params dictionary.") + if not (isinstance(rho, (list, jnp.ndarray))): + raise TypeError(f"rho must be a list or an array, got {type(rho).__name__}") + rho = jnp.asarray(rho, dtype=jnp.float64) + if rho.shape != (num_segments,): + raise ValueError(f"rho must have shape ({num_segments},), got {rho.shape}") + self.rho = rho + + # Elastic modulus of the segments + try: + E = params["E"] + except KeyError: + raise KeyError("Parameter 'E' is required in params dictionary.") + if not (isinstance(E, (list, jnp.ndarray))): + raise TypeError(f"E must be a list or an array, got {type(E).__name__}") + E = jnp.asarray(E, dtype=jnp.float64) + if E.shape != (num_segments,): + raise ValueError(f"E must have shape ({num_segments},), got {E.shape}") + self.E = E + + # Shear modulus of the segments + try: + G = params["G"] + except KeyError: + raise KeyError("Parameter 'G' is required in params dictionary.") + if not (isinstance(G, (list, jnp.ndarray))): + raise TypeError(f"G must be a list or an array, got {type(G).__name__}") + G = jnp.asarray(G, dtype=jnp.float64) + if G.shape != (num_segments,): + raise ValueError(f"G must have shape ({num_segments},), got {G.shape}") + self.G = G + + # Damping matrix of the robot + try: + D = params["D"] + except KeyError: + raise KeyError("Parameter 'D' is required in params dictionary.") + if not (isinstance(D, (list, jnp.ndarray))): + raise TypeError(f"D must be a list or an array, got {type(D).__name__}") + D = jnp.asarray(D, dtype=jnp.float64) + expected_D_shape = (num_strains, num_strains) + if D.shape != expected_D_shape: + raise ValueError(f"D must have shape {expected_D_shape}, got {D.shape}") + self.D = D + + # ================================================================ + # Order of Gauss-Legendre quadrature + if not isinstance(order_gauss, int): + raise TypeError( + f"order_gauss must be an integer, got {type(order_gauss).__name__}" + ) + if order_gauss < 1: + raise ValueError(f"param_integration must be at least 1, got {order_gauss}") + Xs, Ws, num_gauss_points = gauss_quadrature(order_gauss, a=0.0, b=1.0) + self.Xs = Xs + self.Ws = Ws + self.num_gauss_points = num_gauss_points + + # ================================================================ + # Strain basis matrix + if strain_selector is None: + strain_selector = jnp.ones(num_strains, dtype=bool) + else: + if not isinstance(strain_selector, (list, jnp.ndarray)): + raise TypeError( + f"strain_selector must be a list or an array, got {type(strain_selector).__name__}" + ) + strain_selector = jnp.asarray(strain_selector) + if not jnp.issubdtype(strain_selector.dtype, jnp.bool_): + raise TypeError( + f"strain_selector must be a boolean array, got {strain_selector.dtype}" + ) + if strain_selector.size != num_strains: + raise ValueError( + f"strain_selector must have {num_strains} elements, got {strain_selector.size}" + ) + strain_selector = strain_selector.reshape(num_strains) + self.B_xi = compute_strain_basis(strain_selector) + + self.num_active_strains = jnp.sum(strain_selector) + + # Rest configuration strain + if xi_star is None: + xi_star = jnp.tile( + jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.0], dtype=jnp.float64), + (num_segments, 1), + ).reshape(num_strains) + else: + if not isinstance(xi_star, (list, jnp.ndarray)): + raise TypeError( + f"xi_star must be a list or an array, got {type(xi_star).__name__}" + ) + xi_star = jnp.asarray(xi_star) + if xi_star.size != num_strains: + raise ValueError( + f"xi_star must have {num_strains} elements, got {xi_star.size}" + ) + xi_star = xi_star.reshape(num_strains) + self.xi_star = xi_star + + # Stiffness function + if stiffness_fn is None: + compute_stiffness_matrix_for_all_segments_fn = vmap( + compute_spatial_stiffness_matrix + ) + + def stiffness_fn( + formulate_in_strain_space: bool = False, + ) -> Array: + L = self.L + r = self.r + E = self.E + G = self.G + + # cross-sectional area and second moment of area + A = jnp.pi * r**2 + Ib = A**2 / (4 * jnp.pi) + J = jnp.pi * r**4 / 2 # Polar moment of inertia + + # stiffness matrix of shape (num_segments, 6, 6) + S_sms = compute_stiffness_matrix_for_all_segments_fn(L, A, Ib, J, E, G) + # we define the elastic matrix of shape (num_strains, num_strains) as K(xi) = K @ xi where K is equal to + S = blk_diag(S_sms) + + if not formulate_in_strain_space: + S = self.B_xi.T @ S @ self.B_xi + + return S + else: + if not callable(stiffness_fn): + raise TypeError( + f"stiffness_fn must be a callable, got {type(stiffness_fn).__name__}" + ) + self.stiffness_fn = stiffness_fn + + # Actuation mapping function + if actuation_mapping_fn is None: + + def actuation_mapping_fn(q: Array, tau: Array) -> Array: + A = self.B_xi.T @ jnp.identity(self.num_strains) @ self.B_xi + alpha = A @ tau + return alpha + else: + if not callable(actuation_mapping_fn): + raise TypeError( + f"actuation_mapping_fn must be a callable, got {type(actuation_mapping_fn).__name__}" + ) + self.actuation_mapping_fn = actuation_mapping_fn + + def classify_segment( + self, + s: Array, + ) -> Tuple[Array, Array]: + """ + Classify the point along the robot to the corresponding segment. + + Args: + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + segment_idx (Array): index of the segment where the point is located + s_segment (Array): point coordinate along the segment in the interval [0, l_segment] + """ + + # Classify the point along the robot to the corresponding segment + segment_idx = jnp.clip(jnp.sum(s > self.L_cum) - 1, 0, self.num_segments - 1) + + # Compute the point coordinate along the segment in the interval [0, l_segment] + s_local = s - self.L_cum[segment_idx] + + return segment_idx, s_local + + def strain( + self, + q: Array, + ) -> Array: + """ + Compute the strain vector from the generalized coordinates. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + xi (Array): strain vector of shape (num_active_strains,) + """ + xi = self.B_xi @ q + self.xi_star + + return xi + + def forward_kinematics( + self, + q: Array, + s: Array, + ) -> Array: + """ + Compute the forward kinematics of the robot at a point s along the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + g_s (Array): forward kinematics of the robot at point s, shape (4, 4) : + [[ R, p], + [0, 0, 0, 1]] where R is the rotation matrix and p is the position vector. + """ + xi = self.strain(q).reshape(self.num_segments, 6) + + # Compute the point coordinate along the segment in the interval [0, l_segment] + segment_idx, s_local = self.classify_segment(s) + + def body_segment_i(g_base_i, i_segment): + length_i = jnp.where(i_segment == segment_idx, s_local, self.L[i_segment]) + xi_i = xi[i_segment] + + # Magnus expansion + Magnus_i = length_i * xi_i + + g_step = lie.exp_gn_SE3(Magnus_i, eps=self.global_eps) + + g_i = g_base_i @ g_step + + return g_i, g_i + + indices_link = jnp.arange(self.num_segments) + + g_ini = self.g0 # Initial position and orientation of the robot base + + _, g_list = lax.scan(f=body_segment_i, init=g_ini, xs=indices_link) + + # # For debugging purposes, you can uncomment the following line to see the list of transformations + # carry = g_ini + # g_list = [] + # for i_segment in indices_link: + # g_tip_i, g_list_i = body_segment_i(carry, i_segment) + # carry = g_tip_i + # g_list.append(g_list_i) + # g_list = jnp.array(g_list) + + g_s = g_list[segment_idx] + + return g_s + + def _J_local(self, q: Array, s: Array) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + _J_local (Array): Jacobian of the forward kinematics at point s, shape (num_segments, 6, 6) + where each row corresponds to a segment. + """ + xi = self.strain(q).reshape(self.num_segments, 6) + + # Classify the point along the robot to the corresponding segment + segment_idx, s_local = self.classify_segment(s) + + # Initial condition + xi_0 = xi[0] + L_0 = self.L[0] + + Ad_g0_inv_L0 = lie.Adjoint_gi_se3_inv(xi_0, L_0, eps=self.global_eps) + Ad_g0_inv_s = lie.Adjoint_gi_se3_inv(xi_0, s_local, eps=self.global_eps) + + T_g0_L0 = lie.Tangent_gi_se3(xi_0, L_0, eps=self.global_eps) + T_g0_s = lie.Tangent_gi_se3(xi_0, s_local, eps=self.global_eps) + + mat_0_L0 = Ad_g0_inv_L0 @ T_g0_L0 + mat_0_s = Ad_g0_inv_s @ T_g0_s + + J_0_L0 = jnp.concatenate( + [mat_0_L0[None, :, :], jnp.zeros((self.num_segments - 1, 6, 6))], axis=0 + ) + J_0_s = jnp.concatenate( + [mat_0_s[None, :, :], jnp.zeros((self.num_segments - 1, 6, 6))], axis=0 + ) + + tuple_J_0 = (J_0_L0, J_0_s) + + # Iteration function + def J_i(tuple_J_prev: Array, i: int) -> Tuple[Tuple[Array, Array], Array]: + J_prev_Lprev, _ = tuple_J_prev + + xi_i = xi[i] + + Ad_gi_inv_Li = lie.Adjoint_gi_se3_inv(xi_i, self.L[i], eps=self.global_eps) + Ad_gi_inv_s = lie.Adjoint_gi_se3_inv(xi_i, s_local, eps=self.global_eps) + + T_gi_Li = lie.Tangent_gi_se3(xi_i, self.L[i], eps=self.global_eps) + T_gi_s = lie.Tangent_gi_se3(xi_i, s_local, eps=self.global_eps) + + mat_i_Li = Ad_gi_inv_Li @ T_gi_Li + mat_i_s = Ad_gi_inv_s @ T_gi_s + + J_i_s = lax.dynamic_update_slice( + jnp.einsum("ij, njk->nik", Ad_gi_inv_s, J_prev_Lprev), + mat_i_s[jnp.newaxis, ...], + (i, 0, 0), + ) + J_i_Li = lax.dynamic_update_slice( + jnp.einsum("ij, njk->nik", Ad_gi_inv_Li, J_prev_Lprev), + mat_i_Li[jnp.newaxis, ...], + (i, 0, 0), + ) + + return (J_i_Li, J_i_s), J_i_s + + indices_links = jnp.arange(1, self.num_segments) + + _, J_array = lax.scan(f=J_i, init=tuple_J_0, xs=indices_links) + + # Add the initial condition to the Jacobian array + J_array = jnp.concatenate([J_0_s[jnp.newaxis, ...], J_array], axis=0) + + # Extract the Jacobian for the segment that contains the point s + _J_local = lax.dynamic_index_in_dim( + J_array, segment_idx, axis=0, keepdims=False + ) + + return _J_local + + def _final_size_jacobian(self, J_full: Array) -> Array: + """ + Convert the Jacobian or its derivative from the full computation form to the selected strains form. + + Args: + J_full (Array): Full Jacobian of shape (num_segments, 6, 6) + + Returns: + J_selected (Array): Jacobian for the selected strains of shape (6, num_active_strains) + """ + J_final = J_full.transpose(1, 0, 2).reshape(6, self.num_strains) + + return J_final + + def _jacobian_bodyframe_full(self, q: Array, s: Array) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot in the body frame (for every strains) + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (6, num_strains) + """ + _J_local = self._J_local(q, s) + + J_local = self._final_size_jacobian(_J_local) + + return J_local + + def jacobian_bodyframe(self, q: Array, s: Array) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot in the body frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (6, num_active_strains) + """ + _J_local = self._J_local(q, s) + + J_local = self._final_size_jacobian(_J_local) @ self.B_xi + + return J_local + + def jacobian_inertialframe(self, q: Array, s: Array) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot in the inertial frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_global (Array): Jacobian of the forward kinematics at point s in the inertial frame, shape (6, num_active_strains) + """ + _J_local = self._J_local(q, s) + + g_s = self.forward_kinematics(q, s) + g_s_wo_rot = jnp.block( + [[g_s[:3, :3], jnp.zeros((3, 1))], [jnp.zeros((1, 3)), jnp.ones((1, 1))]] + ) + Adj_g = lie.Adjoint_g_SE3( + g_s_wo_rot + ) # Adjoint representation of the SE(3) transformation + + _J_global = jnp.einsum( + "ij, njk -> nik", + Adj_g, + _J_local, + ) + + J_global = self._final_size_jacobian(_J_global) @ self.B_xi + + return J_global + + def _J_Jd(self, q: Array, qd: Array, s: Array) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + _J_local (Array): Jacobian of the forward kinematics at point s, shape (num_segments, 6, 6) + _J_d_local (Array): Time-derivative of the Jacobian at point s, shape (num_segments, 6, 6) + """ + xi_d = (self.B_xi @ qd).reshape(self.num_segments, 6) + + # Classify the point along the robot to the corresponding segment + segment_idx, _ = self.classify_segment(s) + + _J_local = self._J_local(q, s) + + # ================================= + # Computation of the time-derivative of the Jacobian + + idx_range = jnp.arange(self.num_segments) + J_i = vmap( + lambda i: lax.dynamic_index_in_dim(_J_local, i, axis=0, keepdims=False) + )(idx_range) # shape: (num_segments, 6, 6) + sum_Jj_xi_d_j = compute_weighted_sums( + _J_local, xi_d, self.num_segments + ) # shape: (num_segments, 6) + adjoint_sum = vmap(lie.adjoint_se3)( + sum_Jj_xi_d_j + ) # shape: (num_segments, 6, 6) + + # Compute the time-derivative of the Jacobian + _J_d_local = jnp.einsum( + "ijk, ikl->ijl", adjoint_sum, J_i + ) # shape: (num_segments, 6, 6) + + # Replace the elements of J_d_segment_SE3 for i > segment_idx by null matrices + _J_d_local = jnp.where( + jnp.arange(self.num_segments)[:, None, None] > segment_idx, + jnp.zeros_like(_J_d_local), + _J_d_local, + ) + + return _J_local, _J_d_local + + def _jacobian_and_derivative_bodyframe_full( + self, q: Array, qd: Array, s: Array + ) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot in the body frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (6, num_active_strains) + J_d_local (Array): Time-derivative of the Jacobian at point s in the body frame, shape (6, num_active_strains) + """ + _J_local, _J_d_local = self._J_Jd(q, qd, s) + + J_local = self._final_size_jacobian(_J_local) + J_d_local = self._final_size_jacobian(_J_d_local) + + return J_local, J_d_local + + def jacobian_and_derivative_bodyframe( + self, q: Array, qd: Array, s: Array + ) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot in the body frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (6, num_active_strains) + J_d_local (Array): Time-derivative of the Jacobian at point s in the body frame, shape (6, num_active_strains) + """ + _J_local, _J_d_local = self._J_Jd(q, qd, s) + + J_local = self._final_size_jacobian(_J_local) @ self.B_xi + J_d_local = self._final_size_jacobian(_J_d_local) @ self.B_xi + + return J_local, J_d_local + + def jacobian_and_derivative_inertialframe( + self, q: Array, qd: Array, s: Array + ) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot in the inertial frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_global (Array): Jacobian of the forward kinematics at point s in the inertial frame, shape (6, num_active_strains) + J_d_global (Array): Time-derivative of the Jacobian at point s in the inertial frame, shape (6, num_active_strains) + """ + _J_local, _J_d_local = self._J_Jd(q, qd, s) + + g_s = self.forward_kinematics(q, s) + g_s_wo_rot = jnp.block( + [[g_s[:3, :3], jnp.zeros((3, 1))], [jnp.zeros((1, 3)), jnp.ones((1, 1))]] + ) + Adj_g = lie.Adjoint_g_SE3( + g_s_wo_rot + ) # Adjoint representation of the SE(3) transformation + + _J_global = jnp.einsum( + "ijk, ikl -> ijl", + Adj_g, + _J_local, + ) + _J_d_global = jnp.einsum( + "ijk, ikl -> ijl", + Adj_g, + _J_d_local, + ) + + J_global = self._final_size_jacobian(_J_global) @ self.B_xi + J_d_global = self._final_size_jacobian(_J_d_global) @ self.B_xi + + return J_global, J_d_global + + def jacobian( + self, + q: Array, + s: Array, + ) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot in the body frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (6, num_active_strains) + """ + J_local = self.jacobian_bodyframe(q, s) + + return J_local + + def jacobian_and_derivative( + self, + q: Array, + qd: Array, + s: Array, + ) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot in the body frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (6, num_active_strains) + J_d_local (Array): Time-derivative of the Jacobian at point s in the body frame, shape (6, num_active_strains) + """ + J_local, J_d_local = self.jacobian_and_derivative_bodyframe(q, qd, s) + + return J_local, J_d_local + + # ========================================== + # Useful functions for the system + + def _local_cross_sectional_area(self, i: int) -> Array: + """ + Compute the local cross-sectional area for the i-th segment. + + Args: + i (int): index of the segment + + Returns: + A_i (Array): local cross-sectional area of the i-th segment + """ + A_i = jnp.pi * self.r[i] ** 2 # Cross-sectional area + return A_i + + def _local_mass_matrix(self, i: int) -> Array: + """ + Compute the local mass matrix for the i-th segment. + + Args: + i (int): index of the segment + Returns: + M_i (Array): local mass matrix of shape (6, 6) for the i-th segment + """ + rho_i = self.rho[i] + A_i = self._local_cross_sectional_area(i) # Cross-sectional area + I_i = A_i**2 / (4 * jnp.pi) # Second moment of area + + M_i = rho_i * jnp.diag(jnp.array([I_i, I_i, I_i, A_i, A_i, A_i])) + return M_i + + # =========================================== + # Dynamical matrices computation + + def _inertia_full_matrix( + self, + q: Array, + ) -> Array: + """ + Compute the full inertia matrix of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + B_full (Array): Full inertia matrix of shape (num_strains, num_strains). + """ + + def B_i(i): + Xs_scaled, Ws_scaled = scale_gaussian_quadrature( + self.Xs, self.Ws, self.L_cum[i], self.L_cum[i + 1] + ) + M_i = self._local_mass_matrix(i) + + def B_j(j): + Xs_j = Xs_scaled[j] + Ws_j = Ws_scaled[j] + J_j = self._jacobian_bodyframe_full(q, Xs_j) + return Ws_j * J_j.T @ M_i @ J_j + + # B_blocks_i = vmap(B_j)(jnp.arange(self.num_gauss_points)) + + # For debugging purposes, you can uncomment the following line to see the step-by-step computation + B_blocks_i = jnp.stack( + [B_j(j) for j in range(self.num_gauss_points)], axis=0 + ) + + return B_blocks_i + + # B_blocks_tot = vmap(B_i)(jnp.arange(self.num_segments)) + + # For debugging purposes, you can uncomment the following line to see the step-by-step computation + B_blocks_tot = jnp.stack([B_i(i) for i in range(self.num_segments)], axis=0) + + B_full = jnp.sum( + B_blocks_tot, axis=(0, 1) + ) # Sum over segments and Gauss points + + return B_full + + def inertia_matrix( + self, + q: Array, + ) -> Array: + """ + Compute the inertia matrix of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + B (Array): Inertia matrix of shape (num_active_strains, num_active_strains). + """ + B_full = self._inertia_full_matrix(q) + + B = self.B_xi.T @ B_full @ self.B_xi + + return B + + def _coriolis_full_matrix( + self, + q: Array, + qd: Array, + ) -> Array: + """ + Compute the full Coriolis matrix of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + + Returns: + C_full (Array): Full Coriolis matrix of shape (num_strains, num_strains). + """ + + def C_i(i): + Xs_scaled, Ws_scaled = scale_gaussian_quadrature( + self.Xs, self.Ws, self.L_cum[i], self.L_cum[i + 1] + ) + M_i = self._local_mass_matrix(i) + + def C_j(j): + Xs_j = Xs_scaled[j] + Ws_j = Ws_scaled[j] + J_j, J_d_j = self._jacobian_and_derivative_bodyframe_full(q, qd, Xs_j) + return Ws_j * ( + J_j.T + @ ( + M_i @ J_d_j + + lie.coadjoint_se3(J_j @ self.B_xi @ qd) @ M_i @ J_j + ) + ) + + C_blocks_i = vmap(C_j)(jnp.arange(self.num_gauss_points)) + + return C_blocks_i + + C_blocks_tot = vmap(C_i)(jnp.arange(self.num_segments)) + + C_full = jnp.sum(C_blocks_tot, axis=(0, 1)) + + return C_full + + def coriolis_matrix( + self, + q: Array, + qd: Array, + ) -> Array: + """ + Compute the Coriolis matrix of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + + Returns: + C (Array): Coriolis matrix of shape (num_active_strains, num_active_strains). + """ + C_full = self._coriolis_full_matrix(q, qd) + + C = self.B_xi.T @ C_full @ self.B_xi + + return C + + def _gravitational_full_vector( + self, + q: Array, + ) -> Array: + """ + Compute the full gravitational vector of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + G (Array): Full gravitational vector of shape (num_strains,). + """ + + def G_i(i): + Xs_scaled, Ws_scaled = scale_gaussian_quadrature( + self.Xs, self.Ws, self.L_cum[i], self.L_cum[i + 1] + ) + M_i = self._local_mass_matrix(i) + + def G_j(j): + Xs_j = Xs_scaled[j] + Ws_j = Ws_scaled[j] + Ad_g_inv_j = lie.Adjoint_g_inv_SE3(self.forward_kinematics(q, Xs_j)) + J_j = self._jacobian_bodyframe_full(q, Xs_j) + + return Ws_j * J_j.T @ M_i @ Ad_g_inv_j @ self.g + + G_blocks_segment_i = vmap(G_j)(jnp.arange(self.num_gauss_points)) + + # # For debugging purposes, you can uncomment the following line to see the step-by-step computation + # G_blocks_segment_i = jnp.stack( + # [G_j(j) for j in range(self.num_gauss_points)], axis=0 + # ) + + return G_blocks_segment_i + + G_blocks_tot = vmap(G_i)(jnp.arange(self.num_segments)) + + # # For debugging purposes, you can uncomment the following line to see the step-by-step computation + # G_blocks_tot = jnp.stack( + # [G_i(i) for i in range(self.num_segments)], axis=0 + # ) + + G_full = jnp.sum( + G_blocks_tot, axis=(0, 1) + ) # Sum over links and quadrature points + + return G_full + + def gravitational_vector( + self, + q: Array, + ) -> Array: + """ + Compute the gravitational vector of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + G (Array): Gravitational vector of shape (num_active_strains,). + """ + G_full = self._gravitational_full_vector(q) + + G = self.B_xi.T @ G_full + + return G + + def _stiffness_full_matrix( + self, + ) -> Array: + """ + Compute the full stiffness matrix of the robot. + + Returns: + K_full (Array): Full stiffness matrix of shape (num_strains, num_strains). + """ + K_full = self.stiffness_fn(formulate_in_strain_space=True) + + return K_full + + def stiffness_matrix( + self, + ) -> Array: + """ + Compute the stiffness matrix of the robot. + + Returns: + K (Array): Stiffness matrix of shape (num_active_strains, num_active_strains). + """ + K = self.stiffness_fn() + + return K + + def _damping_full_matrix( + self, + ) -> Array: + """ + Compute the full damping matrix of the robot. + + Args: + None + + Returns: + D (Array): Full damping matrix of shape (num_strains, num_strains). + """ + D_full = self.D + + return D_full + + def damping_matrix( + self, + ) -> Array: + """ + Compute the damping matrix of the robot. + + Args: + None + + Returns: + D (Array): Damping matrix of shape (num_active_strains, num_active_strains). + """ + D_full = self._damping_full_matrix() + + D = self.B_xi.T @ D_full @ self.B_xi + + return D + + def actuation_mapping( + self, + q: Array, + actuation_args: Optional[Tuple] = None, + ) -> Array: + """ + Compute the actuation mapping of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + actuation_args (Tuple, optional): Additional arguments for the actuation mapping function, if any. + + Returns: + alpha (Array): Actuation mapping of shape (num_active_strains, num_active_strains). + """ + alpha = self.actuation_mapping_fn(q, *actuation_args) + + return alpha + + def dynamical_matrices( + self, + q: Array, + qd: Array, + actuation_args: Optional[Tuple] = None, + ) -> Tuple[Array, Array, Array, Array, Array, Array]: + """ + Compute the dynamical matrices of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + actuation_args (Tuple, optional): Additional arguments for the actuation mapping function, if any. + + Returns: + B (Array): Inertia matrix of shape (num_active_strains, num_active_strains). + C (Array): Coriolis matrix of shape (num_active_strains, num_active_strains). + G (Array): Gravitational vector of shape (num_active_strains,). + K (Array): Stiffness matrix of shape (num_active_strains, num_active_strains). + D (Array): Damping matrix of shape (num_active_strains, num_active_strains). + alpha (Array): Actuation mapping of shape (num_active_strains, num_active_strains). + """ + B = self.inertia_matrix(q) + C = self.coriolis_matrix(q, qd) + G = self.gravitational_vector(q) + K = self.stiffness_matrix() + D = self.damping_matrix() + alpha = self.actuation_mapping(q, actuation_args) + + return B, C, G, K, D, alpha + + def kinetic_energy( + self, + q: Array, + qd: Array, + ) -> float: + """ + Compute the kinetic energy of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + + Returns: + T (float): Kinetic energy of the robot. + """ + B = self.inertia_matrix(q) + T = 0.5 * qd.T @ B @ qd + + return T + + def elastic_energy( + self, + q: Array, + ) -> float: + """ + Compute the elastic energy of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + U_K (float): Elastic energy of the robot. + """ + K_full = self._stiffness_full_matrix() + U_K = 0.5 * (self.B_xi @ q).T @ K_full @ (self.B_xi @ q) + + return U_K + + def gravitational_energy( + self, + q: Array, + ) -> float: + """ + Compute the gravitational energy of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + U_G (float): Gravitational energy of the robot. + """ + + def U_G_i(i): + Xs_scaled, Ws_scaled = scale_gaussian_quadrature( + self.Xs, self.Ws, self.L_cum[i], self.L_cum[i + 1] + ) + rho_i = self.rho[i] + A_i = self._local_cross_sectional_area(i) # Cross-sectional area + + def U_G_j(j): + Xs_j = Xs_scaled[j] + Ws_j = Ws_scaled[j] + p_j = jnp.concatenate( + [jnp.zeros(3), self.forward_kinematics(q, Xs_j)[:3, 3]] + ) # Add zeros for the orientation angles + return Ws_j * rho_i * A_i * jnp.dot(p_j, self.g) + + U_G_blocks_segment_i = vmap(U_G_j)(jnp.arange(self.num_gauss_points)) + + return U_G_blocks_segment_i + + U_G_blocks_tot = vmap(U_G_i)(jnp.arange(self.num_segments)) + + U_G = jnp.sum(U_G_blocks_tot, axis=(0, 1)) # Sum over segments and Gauss points + + return U_G + + def potential_energy( + self, + q: Array, + ) -> float: + """ + Compute the potential energy of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + U (float): Potential energy of the robot. + """ + U_K = self.elastic_energy(q) + U_G = self.gravitational_energy(q) + + return U_K + U_G + + def total_energy( + self, + q: Array, + qd: Array, + ) -> float: + """ + Compute the total energy of the robot, which is the sum of kinetic and potential energy. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + + Returns: + E (float): Total energy of the robot. + """ + T = self.kinetic_energy(q, qd) + U = self.potential_energy(q) + E = T + U + return E + + def operational_space_dynamical_matrices( + self, + q: Array, + qd: Array, + s: Array, + operational_space_selector: Tuple = (True, True, True, True, True, True), + ) -> Tuple[Array, Array, Array, Array, Array]: + """ + Compute the operational space dynamical matrices for the robot at a point s along the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + operational_space_selector (Tuple): Selector for the operational space dimensions. + Default is (True, True, True, True, True, True) for all dimensions. + + Returns: + Lambda (Array): Inertia matrix in the operational space, shape (num_operational_space_dims, num_operational_space_dims). + mu (Array): Coriolis and centrifugal matrix in the operational space, shape (num_operational_space_dims,). + J (Array): Jacobian of the forward kinematics at point s in the body frame, shape (num_operational_space_dims, num_active_strains). + J_d (Array): Time-derivative of the Jacobian at point s in the body frame, shape (num_operational_space_dims, num_active_strains). + JB_pinv (Array): Dynamically-consistent pseudo-inverse of the Jacobian, shape (num_active_strains, num_operational_space_dims). + """ + # classify the point along the robot to the corresponding segment + _, s_local = self.classify_segment(s) + + # make operational_space_selector a boolean array + operational_space_selector = jnp.array(operational_space_selector, dtype=bool) + + # Jacobian and its time-derivative + J, J_d = self.jacobian_and_derivative_bodyframe(q, qd, s_local) + + J = J[operational_space_selector, :] + J_d = J_d[operational_space_selector, :] + + # inverse of the inertia matrix in the configuration space + B = self.inertia_matrix(q) + B_inv = jnp.linalg.inv(B) + C = self.coriolis_matrix(q, qd) + + Lambda = jnp.linalg.inv( + J @ B_inv @ J.T + ) # inertia matrix in the operational space + mu = Lambda @ ( + J @ B_inv @ C - J_d + ) # coriolis and centrifugal matrix in the operational space + + JB_pinv = ( + B_inv @ J.T @ Lambda + ) # dynamically-consistent pseudo-inverse of the Jacobian + + return Lambda, mu, J, J_d, JB_pinv + + @eqx.filter_jit + def forward_dynamics( + self, + t: float, + y: Array, + actuation_args: Optional[Tuple] = None, + ) -> Array: + """ + Forward dynamics function. + + Args: + t (float): Current time. + y (Array): State vector containing configuration and velocity. + Shape is (2 * num_strains,). + actuation_args (Tuple, optional): Additional arguments for the actuation mapping function. + Default is None. + Returns: + y_d: Time derivative of the state vector. + """ + + q, qd = jnp.split( + y, 2 + ) # Split the state vector into configuration and velocity + + B, C, G, K, D, alpha = self.dynamical_matrices(q, qd, actuation_args) + + B_inv = jnp.linalg.inv(B) # Inverse of the inertia matrix + qdd = B_inv @ (-C @ qd - G - K @ q - D @ qd + alpha) # Compute the acceleration + + y_d = jnp.concatenate([qd, qdd]) + + return y_d + + def resolve_upon_time( + self, + q0: Array, + qd0: Array, + actuation_args: Optional[Tuple] = None, + t0: Optional[float] = 0.0, + t1: Optional[float] = 10.0, + dt: Optional[float] = 1e-4, + skip_steps: Optional[int] = 0, + solver: Optional[AbstractSolver] = Tsit5(), + stepsize_controller: Optional[PIDController] = ConstantStepSize(), + max_steps: Optional[int] = None, + ) -> Tuple[Array, Array, Array]: + """ + Resolve the system dynamics over time using Diffrax. + + Args: + q0 (Array): Initial configuration (strains). + qd0 (Array): Initial velocity (strains). + actuation_args (Tuple, optional): Additional arguments for the actuation function. + Default is None (no actuation). + t0 (float, optionnal): Initial time. + Default is 0.0. + t1 (float, optionnal): Final time. + Default is 10.0. + dt (float, optionnal): Time step for the solver. + Default is 1e-4. + skip_steps (int, optionnal): Number of steps to skip in the output. + This allows to reduce the number of saved time points. + Default is 0. + solver (AbstractSolver, optional): Solver to use for the ODE integration. + Default is Tsit5() (Runge-Kutta 5(4) method). + stepsize_controller (PIDController, optional): Stepsize controller for the solver. + Default is ConstantStepSize(). + max_steps (int, optional): Maximum number of steps for the solver. + Default is None (no limit). + + Returns: + ts (Array): Time points at which the solution is saved. + qs (Array): Configuration (strains) at the saved time points. + qds (Array): Velocity (strains) at the saved time points. + """ + y0 = jnp.concatenate([q0, qd0]) # Initial state vector + + term = ODETerm(self.forward_dynamics) + + t = jnp.arange(t0, t1, dt) # Time points for the solution + saveat = SaveAt(ts=t[::skip_steps]) # Save at specified time points + + sol = diffeqsolve( + terms=term, + solver=solver, + t0=t[0], + t1=t[-1], + dt0=dt, + y0=y0, + args=actuation_args, + saveat=saveat, + stepsize_controller=stepsize_controller, + max_steps=max_steps, + ) + + ts = sol.ts + # Extract the configuration and velocity from the solution + y_out = sol.ys + qs, qds = jnp.split(y_out, 2, axis=1) + + return ts, qs, qds diff --git a/src/jsrm/systems/planar_pcs.py b/src/jsrm/systems/planar_pcs.py new file mode 100644 index 0000000..fe45390 --- /dev/null +++ b/src/jsrm/systems/planar_pcs.py @@ -0,0 +1,1477 @@ +from jax import Array, lax, vmap +from jax import numpy as jnp + +from typing import Callable, Dict, Tuple, Optional + +import equinox as eqx + +from .utils import ( + compute_strain_basis, + compute_planar_stiffness_matrix, + gauss_quadrature, + scale_gaussian_quadrature, +) +from jsrm.math_utils import ( + blk_diag, + compute_weighted_sums, +) +import jsrm.utils.lie_algebra as lie + +from diffrax import ( + diffeqsolve, + ODETerm, + SaveAt, + Tsit5, + PIDController, + ConstantStepSize, + AbstractSolver, +) + + +class PlanarPCS(eqx.Module): + """ + Planar Piecewise Constant Strain (PCS) model for 2D soft continuum robots. + + This class implements the geometric and dynamic modeling of a 2D soft robot + using the Cosserat rod theory and piecewise constant strain assumption. + It supports computation of forward kinematics, Jacobians, dynamical matrices. + + Attributes: + ---------- + num_segments : int + Number of segments (constant strain sections) along the robot. + th0 : Array + Initial orientation angle of the robot in radians. + g : Array + Gravitational acceleration vector (embedded in a 3D vector). + [0, g_x, g_y] + L, r, E, G, rho, D : Array + Physical properties of each segment (length, radius, elastic/shear modulus, etc.). + num_active_strains : int + Number of active strain components (based on strain_selector). + num_strains : int + Total number of strain components (6 * num_segments). + B_xi : Array + Basis matrix for projecting active strains. + xi_star : Array + Rest strain (reference configuration) of the robot. + num_gauss_points : int + Number of points used for numerical integration. + Corresponds to the order of Gauss-Legendre quadrature + 2 (for the endpoints). + Xs, Ws : Array + Gauss-Legendre quadrature nodes and weights for numerical integration. + stiffness_fn : Callable + Function to compute the full stiffness matrix. + actuation_mapping_fn : Callable + Function to map actuation torques into strain space. + + Methods: + ------- + strain(q: Array) -> Array: + Computes the strain vector from generalized coordinates. + forward_kinematics(q: Array, s: Array) -> Array: + Computes the forward kinematics at a point s along the robot. + jacobian(q: Array, s: Array) -> Array: + Computes the Jacobian of the forward kinematics at a point s (global frame). + jacobian_and_derivative(q: Array, qd: Array, s: Array) -> Tuple[Array, Array]: + Computes the Jacobian and its time derivative at a point s (global frame). + dynamical_matrix(q: Array, qd: Array, actuation_args: Tuple[Array]) -> Array: + Computes the dynamical matrix for the system. + resolve_upon_time(q0: Array, qd0: Array, actuation_args: Tuple[Array], t0: float, t1: float, dt: float, skip_steps: int, solver: AbstractSolver, max_steps: Optional[int] = None) -> Tuple[Array, Array, Array]: + Simulates the robot dynamics over time using the specified solver. + forward_dynamics(t: float, y: Array, actuation_args: Optional[Tuple]) -> Array: + Computes the forward dynamics of the system at a given time t. + + kinetic_energy(q: Array, qd: Array) -> float: + Computes the kinetic energy of the system. + elastic_energy(q: Array) -> float: + Computes the elastic potential energy of the system. + gravitational_energy(q: Array) -> float: + Computes the gravitational potential energy of the system. + potential_energy(q: Array) -> float: + Computes the total potential energy (elastic + gravitational) of the system. + total_energy(q: Array, qd: Array) -> float: + Computes the total energy (kinetic + potential) of the system. + + operational_space_dynamical_matrices(q: Array, qd: Array, s: Array, operational_space_selector: Optional[Tuple]) -> Tuple[Array, Array, Array, Array, Array]: + Computes the operational space dynamical matrices that define the operational space dynamics of the systems. Can be used, for example, for operational space control. + + inertia_matrix(q: Array) -> Array: + Computes the inertia matrix of the system. + coriolis_matrix(q: Array, qd: Array) -> Array: + Computes the Coriolis matrix of the system. + gravitational_vector(q: Array) -> Array: + Computes the gravitational force vector acting on the system. + stiffness_matrix() -> Array: + Computes the stiffness matrix of the system. + damping_matrix() -> Array: + Computes the damping matrix of the system. + actuation_mapping(q: Array, actuation_args: Tuple[Array]) -> Array: + Computes the actuation mapping of the system. + + classify_segment(s: Array) -> Tuple[Array, Array]: + Classifies a point along the robot to its corresponding segment and local coordinate. + jacobian_bodyframe(q: Array, s: Array) -> Array: + Computes the Jacobian of the forward kinematics at a point s (body frame). + jacobian_inertialframe(q: Array, s: Array) -> Array: + Computes the Jacobian of the forward kinematics at a point s (inertial frame). + jacobian_and_derivative_bodyframe(q: Array, qd: Array, s: Array) -> Tuple[Array, Array]: + Computes the Jacobian and its time derivative at a point s (body frame). + jacobian_and_derivative_inertialframe(q: Array, qd: Array, s: Array) -> Tuple[Array, Array]: + Computes the Jacobian and its time derivative at a point s (inertial frame). + + Notes: + ----- + - The strain vector is composed of 3 components per segment: + [kappa_z, sigma_x, sigma_y]. + By default, the rod is assumed to be straight and aligned with the y-axis, + so the rest strain is set to [0, 0, 1]. + Thus: - kappa_z corresponds to bending around the z-axis, + - sigma_x corresponds to shear along the x-axis, + - sigma_y corresponds to axial strain along the y-axis. + + """ + + # Robot parameters + th0: Array # Initial orientation angle [rad] + g: Array # Gravitational acceleration vector + + L: Array # Length of the segments + L_cum: Array # Cumulative length of the segments + r: Array # Radius of the segments + rho: Array + E: Array # Young's modulus of the segments + G: Array # Shear modulus of the segments + D: Array # Damping coefficient of the segments + + global_eps: float = jnp.finfo(jnp.float64).eps + + stiffness_fn: Callable = eqx.static_field() + actuation_mapping_fn: Callable = eqx.static_field() + + num_segments: int = eqx.static_field() + num_gauss_points: int = eqx.static_field() # + num_strains: int = eqx.static_field() # Number of strains (3 * num_segments) + + xi_star: Array # Rest configuration strain + B_xi: Array # Strain basis matrix + num_active_strains: Array # Number of selected strains + + Xs: Array # Gauss nodes + Ws: Array # Gauss weights + + def __init__( + self, + num_segments: int, + params: Dict[str, Array], + order_gauss: int = 5, + strain_selector: Optional[Array] = None, + xi_star: Optional[Array] = None, + stiffness_fn: Optional[Callable] = None, + actuation_mapping_fn: Optional[Callable] = None, + ) -> "PlanarPCS": + """ + Initialize the PlanarPCS class. + + Args: + num_segments (int): + Number of segments in the robot. + params (Dict[str, Array]): + Dictionary containing the robot parameters: + - "th0": Initial orientation angle [rad] + - "l": Length of each segment [m] + - "r": Radius of each segment [m] + - "rho": Density of each segment [kg/m^3] + - "g": Gravitational acceleration vector [m/s^2] + - "E": Elastic modulus of each segment [Pa] + - "G": Shear modulus of each segment [Pa] + order_gauss (int, optional): + Order of the Gauss-Legendre quadrature for integration over each segment. + Defaults to 5. + strain_selector (Optional[Array], optional): + Boolean array of shape (3 * num_segments,) specifying which strain components are active. + Defaults to all strains active (i.e. all True). + xi_star (Optional[Array], optional): + Rest strain of shape (3 * num_segments,). + Defaults to 0.0 for bending and shear strains, and 1.0 for axial strain (along local y-axis). + stiffness_fn (Optional[Callable], optional): + Function to compute the stiffness matrix. + Defaults to : + l_i * diag( E_i * Ib_i, # bending Z + 4/3 * A_i * G_i, # shear X + A_i * E_i, # axial Y) + actuation_mapping_fn (Optional[Callable], optional): + Function to compute the actuation mapping. + This function needs to take as input: + - q: generalized coordinates of shape (num_active_strains,) + - actuation_args: tuple containing the actuation parameters (e.g. torques (tau,)). + Defaults to identity linear mapping. actuation_args = (tau,) + + """ + # Number of segments + if not isinstance(num_segments, int): + raise TypeError( + f"num_segments must be an integer, got {type(num_segments).__name__}" + ) + if num_segments < 1: + raise ValueError(f"num_segments must be at least 1, got {num_segments}") + self.num_segments = num_segments + + num_strains = 3 * num_segments + self.num_strains = num_strains + + # ================================================================ + # Robot parameters + + # Initial orientation angle + try: + th0 = params["th0"] + except KeyError: + raise KeyError("Parameter 'th0' is required in params dictionary.") + if not (isinstance(th0, (float, int, jnp.ndarray))): + raise TypeError( + f"th0 must be a float, int, or an array, got {type(th0).__name__}" + ) + th0 = jnp.asarray(th0, dtype=jnp.float64) + self.th0 = th0 + + # Gravitational acceleration vector + try: + g = params["g"] + except KeyError: + raise KeyError("Parameter 'g' is required in params dictionary.") + if not (isinstance(g, (list, jnp.ndarray))): + raise TypeError(f"g must be a list or an array, got {type(g).__name__}") + g = jnp.asarray(g, dtype=jnp.float64) + if g.size != 2: + raise ValueError(f"g must be a vector of shape (2,), got {g.size}") + self.g = jnp.concatenate( + [jnp.zeros(1), g] + ) # Add a zero for the orientation angle + + # Lengths of the segments + try: + L = params["l"] + except KeyError: + raise KeyError("Parameter 'l' is required in params dictionary.") + if not (isinstance(L, (list, jnp.ndarray))): + raise TypeError(f"l must be a list or an array, got {type(L).__name__}") + L = jnp.asarray(L, dtype=jnp.float64) + if L.shape != (num_segments,): + raise ValueError(f"l must have shape ({num_segments},), got {L.shape}") + self.L = L + + L_cum = jnp.cumsum(jnp.concatenate([jnp.zeros(1), self.L])) + self.L_cum = L_cum + + # Radius of the segments + try: + r = params["r"] + except KeyError: + raise KeyError("Parameter 'r' is required in params dictionary.") + if not (isinstance(r, (list, jnp.ndarray))): + raise TypeError(f"r must be a list or an array, got {type(r).__name__}") + r = jnp.asarray(r, dtype=jnp.float64) + if r.shape != (num_segments,): + raise ValueError(f"r must have shape ({num_segments},), got {r.shape}") + self.r = r + + # Densities of the segments + try: + rho = params["rho"] + except KeyError: + raise KeyError("Parameter 'rho' is required in params dictionary.") + if not (isinstance(rho, (list, jnp.ndarray))): + raise TypeError(f"rho must be a list or an array, got {type(rho).__name__}") + rho = jnp.asarray(rho, dtype=jnp.float64) + if rho.shape != (num_segments,): + raise ValueError(f"rho must have shape ({num_segments},), got {rho.shape}") + self.rho = rho + + # Elastic modulus of the segments + try: + E = params["E"] + except KeyError: + raise KeyError("Parameter 'E' is required in params dictionary.") + if not (isinstance(E, (list, jnp.ndarray))): + raise TypeError(f"E must be a list or an array, got {type(E).__name__}") + E = jnp.asarray(E, dtype=jnp.float64) + if E.shape != (num_segments,): + raise ValueError(f"E must have shape ({num_segments},), got {E.shape}") + self.E = E + + # Shear modulus of the segments + try: + G = params["G"] + except KeyError: + raise KeyError("Parameter 'G' is required in params dictionary.") + if not (isinstance(G, (list, jnp.ndarray))): + raise TypeError(f"G must be a list or an array, got {type(G).__name__}") + G = jnp.asarray(G, dtype=jnp.float64) + if G.shape != (num_segments,): + raise ValueError(f"G must have shape ({num_segments},), got {G.shape}") + self.G = G + + # Damping matrix of the robot + try: + D = params["D"] + except KeyError: + raise KeyError("Parameter 'D' is required in params dictionary.") + if not (isinstance(D, (list, jnp.ndarray))): + raise TypeError(f"D must be a list or an array, got {type(D).__name__}") + D = jnp.asarray(D, dtype=jnp.float64) + expected_D_shape = (num_strains, num_strains) + if D.shape != expected_D_shape: + raise ValueError(f"D must have shape {expected_D_shape}, got {D.shape}") + self.D = D + + # ================================================================ + # Order of Gauss-Legendre quadrature + if not isinstance(order_gauss, int): + raise TypeError( + f"order_gauss must be an integer, got {type(order_gauss).__name__}" + ) + if order_gauss < 1: + raise ValueError(f"param_integration must be at least 1, got {order_gauss}") + Xs, Ws, num_gauss_points = gauss_quadrature(order_gauss, a=0.0, b=1.0) + self.Xs = Xs + self.Ws = Ws + self.num_gauss_points = num_gauss_points + + # ================================================================ + # Strain basis matrix + if strain_selector is None: + strain_selector = jnp.ones(num_strains, dtype=bool) + else: + if not isinstance(strain_selector, (list, jnp.ndarray)): + raise TypeError( + f"strain_selector must be a list or an array, got {type(strain_selector).__name__}" + ) + strain_selector = jnp.asarray(strain_selector) + if not jnp.issubdtype(strain_selector.dtype, jnp.bool_): + raise TypeError( + f"strain_selector must be a boolean array, got {strain_selector.dtype}" + ) + if strain_selector.size != num_strains: + raise ValueError( + f"strain_selector must have {num_strains} elements, got {strain_selector.size}" + ) + strain_selector = strain_selector.reshape(num_strains) + self.B_xi = compute_strain_basis(strain_selector) + + self.num_active_strains = jnp.sum(strain_selector) + + # Rest configuration strain + if xi_star is None: + xi_star = jnp.tile( + jnp.array([0.0, 0.0, 1.0], dtype=jnp.float64), (num_segments, 1) + ).reshape(num_strains) + else: + if not isinstance(xi_star, (list, jnp.ndarray)): + raise TypeError( + f"xi_star must be a list or an array, got {type(xi_star).__name__}" + ) + xi_star = jnp.asarray(xi_star) + if xi_star.size != num_strains: + raise ValueError( + f"xi_star must have {num_strains} elements, got {xi_star.size}" + ) + xi_star = xi_star.reshape(num_strains) + self.xi_star = xi_star + + # Stiffness function + if stiffness_fn is None: + compute_stiffness_matrix_for_all_segments_fn = vmap( + compute_planar_stiffness_matrix + ) + + def stiffness_fn( + formulate_in_strain_space: bool = False, + ) -> Array: + L = self.L + r = self.r + E = self.E + G = self.G + + # cross-sectional area and second moment of area + A = jnp.pi * r**2 + Ib = A**2 / (4 * jnp.pi) + + # stiffness matrix of shape (num_segments, 3, 3) + S_sms = compute_stiffness_matrix_for_all_segments_fn(L, A, Ib, E, G) + # we define the elastic matrix of shape (num_strains, num_strains) as K(xi) = K @ xi where K is equal to + S = blk_diag(S_sms) + + if not formulate_in_strain_space: + S = self.B_xi.T @ S @ self.B_xi + + return S + else: + if not callable(stiffness_fn): + raise TypeError( + f"stiffness_fn must be a callable, got {type(stiffness_fn).__name__}" + ) + self.stiffness_fn = stiffness_fn + + # Actuation mapping function + if actuation_mapping_fn is None: + + def actuation_mapping_fn(q: Array, tau: Array) -> Array: + A = self.B_xi.T @ jnp.identity(self.num_strains) @ self.B_xi + alpha = A @ tau + return alpha + else: + if not callable(actuation_mapping_fn): + raise TypeError( + f"actuation_mapping_fn must be a callable, got {type(actuation_mapping_fn).__name__}" + ) + self.actuation_mapping_fn = actuation_mapping_fn + + def classify_segment( + self, + s: Array, + ) -> Tuple[Array, Array]: + """ + Classify the point along the robot to the corresponding segment. + + Args: + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + segment_idx (Array): index of the segment where the point is located + s_segment (Array): point coordinate along the segment in the interval [0, l_segment] + """ + + # Classify the point along the robot to the corresponding segment + segment_idx = jnp.clip(jnp.sum(s > self.L_cum) - 1, 0, self.num_segments - 1) + + # Compute the point coordinate along the segment in the interval [0, l_segment] + s_local = s - self.L_cum[segment_idx] + + return segment_idx, s_local + + def strain( + self, + q: Array, + ) -> Array: + """ + Compute the strain vector from the generalized coordinates. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + xi (Array): strain vector of shape (num_active_strains,) + """ + xi = self.B_xi @ q + self.xi_star + + return xi + + def chi( + self, + xi: Array, + s: Array, + ) -> Array: + """ + Compute the forward kinematics of the robot. + + Args: + xi (Array): strain vector of shape (3*num_segments,) where each row corresponds to a segment + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + chi_s (Array): forward kinematics of the robot at point s, shape (3,) : [theta, x, y] + """ + xi = xi.reshape(self.num_segments, 3) + + segment_idx, s_local = self.classify_segment(s) + + chi_0 = jnp.concatenate( + [self.th0[None], jnp.zeros(2)] + ) # Initial configuration [theta, x, y] + + # Iteration function + def chi_i(chi_prev: Array, i: int) -> Array: + th_prev = chi_prev[0] + p_prev = chi_prev[1:] + + kappa_i = xi[i, 0] + sigmas_i = xi[i, 1:] + + l_i = jnp.where(i == segment_idx, s_local, self.L[i]) + + th = th_prev + kappa_i * l_i + + int_cos_th = jnp.where( + jnp.abs(kappa_i) < self.global_eps, + l_i * jnp.cos(th_prev), + (jnp.sin(th) - jnp.sin(th_prev)) / kappa_i, + ) + int_sin_th = jnp.where( + jnp.abs(kappa_i) < self.global_eps, + l_i * jnp.sin(th_prev), + -(jnp.cos(th) - jnp.cos(th_prev)) / kappa_i, + ) + + R = jnp.stack( + [ + jnp.stack([int_cos_th, -int_sin_th]), + jnp.stack([int_sin_th, int_cos_th]), + ] + ) + + p = p_prev + R @ sigmas_i + + chi = jnp.concatenate([th[None], p]) + + return chi, chi + + _, chi_list = lax.scan(f=chi_i, init=chi_0, xs=jnp.arange(self.num_segments)) + + chi_s = chi_list[segment_idx] + + return chi_s + + def forward_kinematics( + self, + q: Array, + s: Array, + ) -> Array: + """ + Compute the forward kinematics of the robot at a point s along the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + chi (Array): forward kinematics of the robot at point s, shape (3,) : [theta, x, y] + """ + xi = self.strain(q) + + chi = self.chi(xi, s) + + return chi + + def _J_local(self, q: Array, s: Array) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + _J_local (Array): Jacobian of the forward kinematics at point s, shape (num_segments, 3, 3) + where each row corresponds to a segment. + """ + xi = self.strain(q).reshape(self.num_segments, 3) + + # Classify the point along the robot to the corresponding segment + segment_idx, s_local = self.classify_segment(s) + + # Initial condition + xi_0 = xi[0] + L_0 = self.L[0] + + Ad_g0_inv_L0 = lie.Adjoint_gi_se2_inv(xi_0, L_0, eps=self.global_eps) + Ad_g0_inv_s = lie.Adjoint_gi_se2_inv(xi_0, s_local, eps=self.global_eps) + + T_g0_L0 = lie.Tangent_gi_se2(xi_0, L_0, eps=self.global_eps) + T_g0_s = lie.Tangent_gi_se2(xi_0, s_local, eps=self.global_eps) + + mat_0_L0 = Ad_g0_inv_L0 @ T_g0_L0 + mat_0_s = Ad_g0_inv_s @ T_g0_s + + J_0_L0 = jnp.concatenate( + [mat_0_L0[None, :, :], jnp.zeros((self.num_segments - 1, 3, 3))], axis=0 + ) + J_0_s = jnp.concatenate( + [mat_0_s[None, :, :], jnp.zeros((self.num_segments - 1, 3, 3))], axis=0 + ) + + tuple_J_0 = (J_0_L0, J_0_s) + + # Iteration function + def J_i(tuple_J_prev: Array, i: int) -> Tuple[Tuple[Array, Array], Array]: + J_prev_Lprev, _ = tuple_J_prev + + xi_i = xi[i] + + Ad_gi_inv_Li = lie.Adjoint_gi_se2_inv(xi_i, self.L[i], eps=self.global_eps) + Ad_gi_inv_s = lie.Adjoint_gi_se2_inv(xi_i, s_local, eps=self.global_eps) + + T_gi_Li = lie.Tangent_gi_se2(xi_i, self.L[i], eps=self.global_eps) + T_gi_s = lie.Tangent_gi_se2(xi_i, s_local, eps=self.global_eps) + + mat_i_Li = Ad_gi_inv_Li @ T_gi_Li + mat_i_s = Ad_gi_inv_s @ T_gi_s + + J_i_s = lax.dynamic_update_slice( + jnp.einsum("ij, njk->nik", Ad_gi_inv_s, J_prev_Lprev), + mat_i_s[jnp.newaxis, ...], + (i, 0, 0), + ) + J_i_Li = lax.dynamic_update_slice( + jnp.einsum("ij, njk->nik", Ad_gi_inv_Li, J_prev_Lprev), + mat_i_Li[jnp.newaxis, ...], + (i, 0, 0), + ) + + return (J_i_Li, J_i_s), J_i_s + + indices_links = jnp.arange(1, self.num_segments) + + _, J_array = lax.scan(f=J_i, init=tuple_J_0, xs=indices_links) + + # Add the initial condition to the Jacobian array + J_array = jnp.concatenate([J_0_s[jnp.newaxis, ...], J_array], axis=0) + + # Extract the Jacobian for the segment that contains the point s + _J_local = lax.dynamic_index_in_dim( + J_array, segment_idx, axis=0, keepdims=False + ) + + return _J_local + + def _final_size_jacobian(self, J_full: Array) -> Array: + """ + Convert the Jacobian or its derivative from the full computation form to the selected strains form. + + Args: + J_full (Array): Full Jacobian of shape (num_segments, 3, 3) + + Returns: + J_selected (Array): Jacobian for the selected strains of shape (3, num_strains) + """ + J_final = J_full.transpose(1, 0, 2).reshape(3, self.num_strains) + + return J_final + + def _jacobian_bodyframe_full(self, q: Array, s: Array) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot in the body frame (for every strains) + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (3, num_strains) + """ + _J_local = self._J_local(q, s) + + J_local = self._final_size_jacobian(_J_local) + + return J_local + + def jacobian_bodyframe(self, q: Array, s: Array) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot in the body frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (3, num_active_strains) + """ + _J_local = self._J_local(q, s) + + J_local = self._final_size_jacobian(_J_local) @ self.B_xi + + return J_local + + def jacobian_inertialframe(self, q: Array, s: Array) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot in the inertial frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_global (Array): Jacobian of the forward kinematics at point s in the inertial frame, shape (3, num_active_strains) + """ + _J_local = self._J_local(q, s) + + chi = self.forward_kinematics(q, s) + theta = chi[0] + g_i = lie.exp_SE2( + jnp.stack([theta, 0.0, 0.0]) + ) # SE(2) transformation at point s + Adj_gi = lie.Adjoint_g_SE2( + g_i + ) # Adjoint representation of the SE(2) transformation + + _J_global = jnp.einsum( + "ij, njk -> nik", + Adj_gi, + _J_local, + ) + + J_global = self._final_size_jacobian(_J_global) @ self.B_xi + + return J_global + + def _J_Jd(self, q: Array, qd: Array, s: Array) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + _J_local (Array): Jacobian of the forward kinematics at point s, shape (num_segments, 3, 3) + _J_d_local (Array): Time-derivative of the Jacobian at point s, shape (num_segments, 3, 3) + """ + xi_d = (self.B_xi @ qd).reshape(self.num_segments, 3) + + # Classify the point along the robot to the corresponding segment + segment_idx, _ = self.classify_segment(s) + + _J_local = self._J_local(q, s) + + # ================================= + # Computation of the time-derivative of the Jacobian + + idx_range = jnp.arange(self.num_segments) + J_i = vmap( + lambda i: lax.dynamic_index_in_dim(_J_local, i, axis=0, keepdims=False) + )(idx_range) # shape: (num_segments, 3, 3) + sum_Jj_xi_d_j = compute_weighted_sums( + _J_local, xi_d, self.num_segments + ) # shape: (num_segments, 3) + adjoint_sum = vmap(lie.adjoint_se2)( + sum_Jj_xi_d_j + ) # shape: (num_segments, 3, 3) + + # Compute the time-derivative of the Jacobian + _J_d_local = jnp.einsum( + "ijk, ikl->ijl", adjoint_sum, J_i + ) # shape: (num_segments, 3, 3) + + # Replace the elements of J_d_segment_SE2 for i > segment_idx by null matrices + _J_d_local = jnp.where( + jnp.arange(self.num_segments)[:, None, None] > segment_idx, + jnp.zeros_like(_J_d_local), + _J_d_local, + ) + + return _J_local, _J_d_local + + def jacobian_and_derivative_bodyframe( + self, q: Array, qd: Array, s: Array + ) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot in the body frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_local (Array): Jacobian of the forward kinematics at point s in the body frame, shape (3, num_active_strains) + J_d_local (Array): Time-derivative of the Jacobian at point s in the body frame, shape (3, num_active_strains) + """ + _J_local, _J_d_local = self._J_Jd(q, qd, s) + + J_local = self._final_size_jacobian(_J_local) @ self.B_xi + J_d_local = self._final_size_jacobian(_J_d_local) @ self.B_xi + + return J_local, J_d_local + + def jacobian_and_derivative_inertialframe( + self, q: Array, qd: Array, s: Array + ) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot in the inertial frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_global (Array): Jacobian of the forward kinematics at point s in the inertial frame, shape (3, num_active_strains) + J_d_global (Array): Time-derivative of the Jacobian at point s in the inertial frame, shape (3, num_active_strains) + """ + _J_local, _J_d = self._J_Jd(q, qd, s) + + chi = self.forward_kinematics(q, s) + theta = chi[0] + g_i = lie.exp_SE2( + jnp.stack([theta, 0.0, 0.0]) + ) # SE(2) transformation at point s + Adj_gi = lie.Adjoint_g_SE2(g_i) + + _J_global = jnp.einsum( + "ijk, ikl -> ijl", + Adj_gi, + _J_local, + ) + _J_d = jnp.einsum( + "ijk, ikl -> ijl", + Adj_gi, + _J_d, + ) + + J_global = self._final_size_jacobian(_J_global) @ self.B_xi + J_d_global = self._final_size_jacobian(_J_d) @ self.B_xi + + return J_global, J_d_global + + def jacobian( + self, + q: Array, + s: Array, + ) -> Array: + """ + Compute the Jacobian of the forward kinematics at a point s along the robot in the inertial frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_global (Array): Jacobian of the forward kinematics at point s in the inertial frame, shape (3, num_active_strains) + """ + J_global = self.jacobian_inertialframe(q, s) + + return J_global + + def jacobian_and_derivative( + self, + q: Array, + qd: Array, + s: Array, + ) -> Tuple[Array, Array]: + """ + Compute the Jacobian and its time-derivative for the forward kinematics at a point s along the robot in the inertial frame. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + + Returns: + J_global (Array): Jacobian of the forward kinematics at point s in the inertial frame, shape (3, num_active_strains) + J_d_global (Array): Time-derivative of the Jacobian at point s in the inertial frame, shape (3, num_active_strains) + """ + J_global, J_d_global = self.jacobian_and_derivative_inertialframe(q, qd, s) + + return J_global, J_d_global + + # ========================================== + # Useful functions for the system + + def _local_cross_sectional_area(self, i: int) -> Array: + """ + Compute the local cross-sectional area for the i-th segment. + + Args: + i (int): index of the segment + + Returns: + A_i (Array): local cross-sectional area of the i-th segment + """ + A_i = jnp.pi * self.r[i] ** 2 # Cross-sectional area + return A_i + + def _local_mass_matrix(self, i: int) -> Array: + """ + Compute the local mass matrix for the i-th segment. + + Args: + i (int): index of the segment + Returns: + M_i (Array): local mass matrix of shape (3, 3) for the i-th segment + """ + rho_i = self.rho[i] + A_i = self._local_cross_sectional_area(i) # Cross-sectional area + I_i = A_i**2 / (4 * jnp.pi) # Second moment of area + + M_i = rho_i * jnp.diag(jnp.array([I_i, A_i, A_i])) + return M_i + + # =========================================== + # Dynamical matrices computation + + def _inertia_full_matrix( + self, + q: Array, + ) -> Array: + """ + Compute the full inertia matrix of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + B_full (Array): Full inertia matrix of shape (num_strains, num_strains). + """ + + def B_i(i): + Xs_scaled, Ws_scaled = scale_gaussian_quadrature( + self.Xs, self.Ws, self.L_cum[i], self.L_cum[i + 1] + ) + M_i = self._local_mass_matrix(i) + + def B_j(j): + Xs_j = Xs_scaled[j] + Ws_j = Ws_scaled[j] + J_j = self._jacobian_bodyframe_full(q, Xs_j) + return Ws_j * J_j.T @ M_i @ J_j + + B_blocks_i = vmap(B_j)(jnp.arange(self.num_gauss_points)) + + # # For debugging purposes, you can uncomment the following line to see the step-by-step computation + # B_blocks_i = jnp.stack([B_j(j) for j in range(self.num_gauss_points)], axis=0) + + return B_blocks_i + + B_blocks_tot = vmap(B_i)(jnp.arange(self.num_segments)) + + # # For debugging purposes, you can uncomment the following line to see the step-by-step computation + # B_blocks_tot = jnp.stack([B_i(i) for i in range(self.num_segments)], axis=0) + + B_full = jnp.sum( + B_blocks_tot, axis=(0, 1) + ) # Sum over segments and Gauss points + + return B_full + + def inertia_matrix( + self, + q: Array, + ) -> Array: + """ + Compute the inertia matrix of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + B (Array): Inertia matrix of shape (num_active_strains, num_active_strains). + """ + B_full = self._inertia_full_matrix(q) + + B = self.B_xi.T @ B_full @ self.B_xi + + return B + + def _coriolis_full_matrix( + self, + q: Array, + qd: Array, + ) -> Array: + """ + Compute the full Coriolis matrix of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + + Returns: + C_full (Array): Full Coriolis matrix of shape (num_strains, num_strains). + """ + + def C_i(i): + Xs_scaled, Ws_scaled = scale_gaussian_quadrature( + self.Xs, self.Ws, self.L_cum[i], self.L_cum[i + 1] + ) + M_i = self._local_mass_matrix(i) + + def C_j(j): + Xs_j = Xs_scaled[j] + Ws_j = Ws_scaled[j] + J_j, J_d_j = self.jacobian_and_derivative_bodyframe(q, qd, Xs_j) + return Ws_j * ( + J_j.T @ (M_i @ J_d_j + lie.coadjoint_se2(J_j @ qd) @ M_i @ J_j) + ) + + C_blocks_i = vmap(C_j)(jnp.arange(self.num_gauss_points)) + + return C_blocks_i + + C_blocks_tot = vmap(C_i)(jnp.arange(self.num_segments)) + + C_full = jnp.sum(C_blocks_tot, axis=(0, 1)) + + return C_full + + def coriolis_matrix( + self, + q: Array, + qd: Array, + ) -> Array: + """ + Compute the Coriolis matrix of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + + Returns: + C (Array): Coriolis matrix of shape (num_active_strains, num_active_strains). + """ + C_full = self._coriolis_full_matrix(q, qd) + + C = self.B_xi.T @ C_full @ self.B_xi + + return C + + def _gravitational_full_vector( + self, + q: Array, + ) -> Array: + """ + Compute the full gravitational vector of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + G (Array): Full gravitational vector of shape (num_strains,). + """ + + def G_i(i): + Xs_scaled, Ws_scaled = scale_gaussian_quadrature( + self.Xs, self.Ws, self.L_cum[i], self.L_cum[i + 1] + ) + M_i = self._local_mass_matrix(i) + + def G_j(j): + Xs_j = Xs_scaled[j] + Ws_j = Ws_scaled[j] + Ad_g_inv_j = lie.Adjoint_g_inv_SE2( + lie.exp_SE2(self.forward_kinematics(q, Xs_j)) + ) + J_j = self._jacobian_bodyframe_full(q, Xs_j) + + return Ws_j * J_j.T @ M_i @ Ad_g_inv_j @ self.g + + G_blocks_segment_i = vmap(G_j)(jnp.arange(self.num_gauss_points)) + + # # For debugging purposes, you can uncomment the following line to see the step-by-step computation + # G_blocks_segment_i = jnp.stack( + # [G_j(j) for j in range(self.num_gauss_points)], axis=0 + # ) + + return G_blocks_segment_i + + G_blocks_tot = vmap(G_i)(jnp.arange(self.num_segments)) + + # # For debugging purposes, you can uncomment the following line to see the step-by-step computation + # G_blocks_tot = jnp.stack( + # [G_i(i) for i in range(self.num_segments)], axis=0 + # ) + + G_full = jnp.sum( + G_blocks_tot, axis=(0, 1) + ) # Sum over links and quadrature points + + return G_full + + def gravitational_vector( + self, + q: Array, + ) -> Array: + """ + Compute the gravitational vector of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + G (Array): Gravitational vector of shape (num_active_strains,). + """ + G_full = self._gravitational_full_vector(q) + + G = self.B_xi.T @ G_full + + return G + + def _stiffness_full_matrix( + self, + ) -> Array: + """ + Compute the full stiffness matrix of the robot. + + Returns: + K_full (Array): Full stiffness matrix of shape (num_strains, num_strains). + """ + K_full = self.stiffness_fn(formulate_in_strain_space=True) + + return K_full + + def stiffness_matrix( + self, + ) -> Array: + """ + Compute the stiffness matrix of the robot. + + Returns: + K (Array): Stiffness matrix of shape (num_active_strains, num_active_strains). + """ + K = self.stiffness_fn() + + return K + + def _damping_full_matrix( + self, + ) -> Array: + """ + Compute the full damping matrix of the robot. + + Args: + None + + Returns: + D (Array): Full damping matrix of shape (num_strains, num_strains). + """ + D_full = self.D + + return D_full + + def damping_matrix( + self, + ) -> Array: + """ + Compute the damping matrix of the robot. + + Args: + None + + Returns: + D (Array): Damping matrix of shape (num_active_strains, num_active_strains). + """ + D_full = self._damping_full_matrix() + + D = self.B_xi.T @ D_full @ self.B_xi + + return D + + def actuation_mapping( + self, + q: Array, + actuation_args: Optional[Tuple] = None, + ) -> Array: + """ + Compute the actuation mapping of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + actuation_args (Tuple, optional): Additional arguments for the actuation mapping function, if any. + + Returns: + alpha (Array): Actuation mapping of shape (num_active_strains, num_active_strains). + """ + alpha = self.actuation_mapping_fn(q, *actuation_args) + + return alpha + + def dynamical_matrices( + self, + q: Array, + qd: Array, + actuation_args: Optional[Tuple] = None, + ) -> Tuple[Array, Array, Array, Array, Array, Array]: + """ + Compute the dynamical matrices of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + actuation_args (Tuple, optional): Additional arguments for the actuation mapping function, if any. + + Returns: + B (Array): Inertia matrix of shape (num_active_strains, num_active_strains). + C (Array): Coriolis matrix of shape (num_active_strains, num_active_strains). + G (Array): Gravitational vector of shape (num_active_strains,). + K (Array): Stiffness matrix of shape (num_active_strains, num_active_strains). + D (Array): Damping matrix of shape (num_active_strains, num_active_strains). + alpha (Array): Actuation mapping of shape (num_active_strains, num_active_strains). + """ + B = self.inertia_matrix(q) + C = self.coriolis_matrix(q, qd) + G = self.gravitational_vector(q) + K = self.stiffness_matrix() + D = self.damping_matrix() + alpha = self.actuation_mapping(q, actuation_args) + + return B, C, G, K, D, alpha + + def kinetic_energy( + self, + q: Array, + qd: Array, + ) -> float: + """ + Compute the kinetic energy of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + + Returns: + T (float): Kinetic energy of the robot. + """ + B = self.inertia_matrix(q) + T = 0.5 * qd.T @ B @ qd + + return T + + def elastic_energy( + self, + q: Array, + ) -> float: + """ + Compute the elastic energy of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + U_K (float): Elastic energy of the robot. + """ + K_full = self._stiffness_full_matrix() + U_K = 0.5 * (self.B_xi @ q).T @ K_full @ (self.B_xi @ q) + + return U_K + + def gravitational_energy( + self, + q: Array, + ) -> float: + """ + Compute the gravitational energy of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + U_G (float): Gravitational energy of the robot. + """ + + def U_G_i(i): + Xs_scaled, Ws_scaled = scale_gaussian_quadrature( + self.Xs, self.Ws, self.L_cum[i], self.L_cum[i + 1] + ) + rho_i = self.rho[i] + A_i = self._local_cross_sectional_area(i) # Cross-sectional area + + def U_G_j(j): + Xs_j = Xs_scaled[j] + Ws_j = Ws_scaled[j] + p_j = ( + self.forward_kinematics(q, Xs_j).at[0].set(0.0) + ) # Set the orientation angle to 0 for gravitational energy computation + return Ws_j * rho_i * A_i * jnp.dot(p_j, self.g) + + U_G_blocks_segment_i = vmap(U_G_j)(jnp.arange(self.num_gauss_points)) + + return U_G_blocks_segment_i + + U_G_blocks_tot = vmap(U_G_i)(jnp.arange(self.num_segments)) + + U_G = jnp.sum(U_G_blocks_tot, axis=(0, 1)) # Sum over segments and Gauss points + + return U_G + + def potential_energy( + self, + q: Array, + ) -> float: + """ + Compute the potential energy of the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + + Returns: + U (float): Potential energy of the robot. + """ + U_K = self.elastic_energy(q) + U_G = self.gravitational_energy(q) + + return U_K + U_G + + def total_energy( + self, + q: Array, + qd: Array, + ) -> float: + """ + Compute the total energy of the robot, which is the sum of kinetic and potential energy. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + + Returns: + E (float): Total energy of the robot. + """ + T = self.kinetic_energy(q, qd) + U = self.potential_energy(q) + E = T + U + return E + + def operational_space_dynamical_matrices( + self, + q: Array, + qd: Array, + s: Array, + operational_space_selector: Tuple = (True, True, True), + ) -> Tuple[Array, Array, Array, Array, Array]: + """ + Compute the operational space dynamical matrices for the robot at a point s along the robot. + + Args: + q (Array): generalized coordinates of shape (num_active_strains,). + qd (Array): time-derivative of the generalized coordinates of shape (num_active_strains,). + s (Array): point coordinate along the robot in the interval [0, L]. + operational_space_selector (Tuple): Selector for the operational space dimensions. + Default is (True, True, True) for all dimensions. + + Returns: + Lambda (Array): Inertia matrix in the operational space, shape (num_operational_space_dims, num_operational_space_dims). + mu (Array): Coriolis and centrifugal matrix in the operational space, shape (num_operational_space_dims,). + J (Array): Jacobian of the forward kinematics at point s in the body frame, shape (num_operational_space_dims, num_active_strains). + J_d (Array): Time-derivative of the Jacobian at point s in the body frame, shape (num_operational_space_dims, num_active_strains). + JB_pinv (Array): Dynamically-consistent pseudo-inverse of the Jacobian, shape (num_active_strains, num_operational_space_dims). + """ + # classify the point along the robot to the corresponding segment + _, s_local = self.classify_segment(s) + + # make operational_space_selector a boolean array + operational_space_selector = jnp.array(operational_space_selector, dtype=bool) + + # Jacobian and its time-derivative + J, J_d = self.jacobian_and_derivative_bodyframe(q, qd, s_local) + + J = J[operational_space_selector, :] + J_d = J_d[operational_space_selector, :] + + # inverse of the inertia matrix in the configuration space + B = self.inertia_matrix(q) + B_inv = jnp.linalg.inv(B) + C = self.coriolis_matrix(q, qd) + + Lambda = jnp.linalg.inv( + J @ B_inv @ J.T + ) # inertia matrix in the operational space + mu = Lambda @ ( + J @ B_inv @ C - J_d + ) # coriolis and centrifugal matrix in the operational space + + JB_pinv = ( + B_inv @ J.T @ Lambda + ) # dynamically-consistent pseudo-inverse of the Jacobian + + return Lambda, mu, J, J_d, JB_pinv + + @eqx.filter_jit + def forward_dynamics( + self, + t: float, + y: Array, + actuation_args: Optional[Tuple] = None, + ) -> Array: + """ + Forward dynamics function. + + Args: + t (float): Current time. + y (Array): State vector containing configuration and velocity. + Shape is (2 * num_strains,). + actuation_args (Tuple, optional): Additional arguments for the actuation mapping function. + Default is None. + Returns: + y_d: Time derivative of the state vector. + """ + + q, qd = jnp.split( + y, 2 + ) # Split the state vector into configuration and velocity + + B, C, G, K, D, alpha = self.dynamical_matrices(q, qd, actuation_args) + + B_inv = jnp.linalg.inv(B) # Inverse of the inertia matrix + qdd = B_inv @ (-C @ qd - G - K @ q - D @ qd + alpha) # Compute the acceleration + + y_d = jnp.concatenate([qd, qdd]) + + return y_d + + def resolve_upon_time( + self, + q0: Array, + qd0: Array, + actuation_args: Optional[Tuple] = None, + t0: Optional[float] = 0.0, + t1: Optional[float] = 10.0, + dt: Optional[float] = 1e-4, + skip_steps: Optional[int] = 0, + solver: Optional[AbstractSolver] = Tsit5(), + stepsize_controller: Optional[PIDController] = ConstantStepSize(), + max_steps: Optional[int] = None, + ) -> Tuple[Array, Array, Array]: + """ + Resolve the system dynamics over time using Diffrax. + + Args: + q0 (Array): Initial configuration (strains). + qd0 (Array): Initial velocity (strains). + actuation_args (Tuple, optional): Additional arguments for the actuation function. + Default is None (no actuation). + t0 (float, optionnal): Initial time. + Default is 0.0. + t1 (float, optionnal): Final time. + Default is 10.0. + dt (float, optionnal): Time step for the solver. + Default is 1e-4. + skip_steps (int, optionnal): Number of steps to skip in the output. + This allows to reduce the number of saved time points. + Default is 0. + solver (AbstractSolver, optional): Solver to use for the ODE integration. + Default is Tsit5() (Runge-Kutta 5(4) method). + stepsize_controller (PIDController, optional): Stepsize controller for the solver. + Default is ConstantStepSize(). + max_steps (int, optional): Maximum number of steps for the solver. + Default is None (no limit). + + Returns: + ts (Array): Time points at which the solution is saved. + qs (Array): Configuration (strains) at the saved time points. + qds (Array): Velocity (strains) at the saved time points. + """ + y0 = jnp.concatenate([q0, qd0]) # Initial state vector + + term = ODETerm(self.forward_dynamics) + + t = jnp.arange(t0, t1, dt) # Time points for the solution + saveat = SaveAt(ts=t[::skip_steps]) # Save at specified time points + + sol = diffeqsolve( + terms=term, + solver=solver, + t0=t[0], + t1=t[-1], + dt0=dt, + y0=y0, + args=actuation_args, + saveat=saveat, + stepsize_controller=stepsize_controller, + max_steps=max_steps, + ) + + ts = sol.ts + # Extract the configuration and velocity from the solution + y_out = sol.ys + qs, qds = jnp.split(y_out, 2, axis=1) + + return ts, qs, qds diff --git a/src/jsrm/systems/planar_pcs_num.py b/src/jsrm/systems/planar_pcs_num.py deleted file mode 100644 index 8021a18..0000000 --- a/src/jsrm/systems/planar_pcs_num.py +++ /dev/null @@ -1,1768 +0,0 @@ -from jax import Array, lax, vmap -from jax import jacobian, grad -from jax import scipy as jscipy -from jax import numpy as jnp -from quadax import GaussKronrodRule - -import numpy as onp -from typing import Callable, Dict, Tuple, Optional, Literal, Any - -from .utils import ( - compute_strain_basis, - compute_planar_stiffness_matrix, - gauss_quadrature, -) -from jsrm.math_utils import blk_diag, blk_concat -from jsrm.utils.lie_operators import ( # To use SE(2) - Tangent_gn_SE2, - Adjoint_gn_SE2_inv, - Adjoint_g_SE2, - adjoint_SE2, - adjoint_star_SE2, -) -from jsrm.utils.lie_operators import ( - compute_weighted_sums, -) - -# To extract the interest coordinates and strains from the SE(3) elements -INTEREST_COORDINATES = jnp.array( - [2, 3, 4] -) # indices of the interest coordinates in the SE3 forward kinematics vector [theta_x, theta_y, theta_z, x, y, z] => [theta_z, x, y] -INTEREST_STRAIN = jnp.array( - [2, 3, 4] -) # indices of the interest strains in the SE3 strain vector [kappa_x, kappa_y, kappa_z, sigma_x, sigma_y, sigma_z] => [kappa_z, sigma_x, sigma_y] -# To reorder the lines to match the forward kinematics vector or strain vector -REORDERED_LINES_FWD_KINE = jnp.array( - [1, 2, 0] -) # reorder the lines to match the forward kinematics vector [x, y, theta] => [sigma_x, sigma_y, kappa_z] -REORDERED_LINES_STRAIN = jnp.array( - [2, 0, 1] -) # reorder the lines to match the strain vector [kappa_z, sigma_x, sigma_y] => [theta, x, y] - - -def factory( - num_segments: int, - strain_selector: Optional[Array] = None, - xi_eq: Optional[Array] = None, - stiffness_fn: Optional[Callable] = None, - actuation_mapping_fn: Optional[Callable] = None, - global_eps: float = jnp.finfo(jnp.float32).eps, - integration_type: Optional[ - Literal["gauss-legendre", "gauss-kronrad", "trapezoid"] - ] = "gauss-legendre", - param_integration: Optional[int] = None, - jacobian_type: Optional[Literal["explicit", "autodiff"]] = "explicit", -) -> Tuple[ - Array, - Callable[[Dict[str, Array], Array, Array, float], Array], - Callable[ - [Dict[str, Array], Array, Array, float], - Tuple[Array, Array, Array, Array, Array, Array], - ], - Dict[str, Callable[..., Any]], -]: - """ - Factory function to create the forward kinematics function for a planar robot. - This function computes the forward kinematics of a planar robot with a given number of segments. - - Args: - num_segments (int): number of segments in the robot. - strain_selector (Array, optional): strain selector array of shape (3 * num_segments, ) - specifying which strain components are active by setting them to True or False. - Defaults to None. - xi_eq (Array, optional): equilibrium strain vector of shape (3 * num_segments, ). - Defaults to 1 for the axial strain and 0 for the bending and shear strains. - stiffness_fn (Callable, optional): function to compute the stiffness matrix. - Defaults to None. - actuation_mapping_fn (Callable, optional): function to compute the actuation mapping. - Defaults to None. - global_eps (float, optional): small number to avoid singularities. - Defaults to 1e-8. - integration_type (str, optional): type of integration to use: "gauss-legendre", "gauss-kronrad" or "trapezoid". - Defaults to "gauss-legendre" for Gaussian quadrature. - param_integration (int, optional): parameter for the integration method. - If None, it is set to 5 for Gauss-Legendre quadrature, 15 for Gauss-Kronrad quadrature and 1000 for trapezoidal integration. - jacobian_type (str, optional): type of Jacobian to compute: "explicit" or "autodiff". - Defaults to "explicit" for explicit Jacobian computation. - - Returns: - B_xi (Array): strain basis matrix of shape (n_xi, n_q) where n_xi is the number of strains and n_q is the number of configuration variables. - forward_kinematics_fn (Callable): function to compute the forward kinematics of the robot. - takes in robot parameters params, configuration vector q, and point coordinate s along the robot, - and returns the pose of the robot at a given point along its length. - dynamical_matrices_fn (Callable): function to compute the dynamical matrices of the robot. - takes in robot parameters params, configuration vector q, configuration velocity q_d, - and returns the dynamical matrices B, C, G, K, D, alpha - auxiliary_fns (Dict[str, Callable]): dictionary of auxiliary functions for the robot. - - "apply_eps_to_bend_strains": function to apply a small number to the bending strains to avoid singularities. - - "classify_segment": function to classify a point along the robot to the corresponding segment. - - "stiffness_fn": function to compute the stiffness matrix of the robot. - - "actuation_mapping_fn": actuation_mapping_fn, - - "jacobian_fn": inertial-frame Jacobian of the forward kinematics function with respect to the strain vector. - - "kinetic_energy_fn": kinetic energy function of the robot. - - "potential_energy_fn": potential energy function of the robot. - - "energy_fn": total energy function of the robot. - - "operational_space_dynamical_matrices_fn": function to compute the operational space dynamical matrices of the robot. - """ - - # ======================================================================================================================= - # Initialize parameters if not provided - # ==================================================== - # Number of segments - if not isinstance(num_segments, int): - raise ValueError( - f"num_segments must be an integer, but got {type(num_segments)}" - ) - if num_segments < 1: - raise ValueError(f"num_segments must be greater than 0, but got {num_segments}") - - # Max number of degrees of freedom = size of the strain vector - n_xi = 3 * num_segments - - # Strain basis matrix - if strain_selector is None: - # activate all strains (i.e. bending, shear, and axial) - strain_selector = jnp.ones((n_xi,), dtype=bool) - if not isinstance(strain_selector, jnp.ndarray): - if isinstance(strain_selector, list): - strain_selector = jnp.array(strain_selector) - else: - raise TypeError( - f"strain_selector must be a jnp.ndarray, but got {type(strain_selector).__name__}" - ) - strain_selector = strain_selector.flatten() - if strain_selector.shape[0] != n_xi: - raise ValueError( - f"strain_selector must have the same shape as the strain vector, but got {strain_selector.shape[0]} instead of {n_xi}" - ) - if not jnp.issubdtype(strain_selector.dtype, jnp.bool_): - raise TypeError( - f"strain_selector must be a boolean array, but got {strain_selector.dtype}" - ) - - # Rest strain - if xi_eq is None: - xi_eq = jnp.zeros((n_xi,)) - # By default, set the axial rest strain (local y-axis) along the entire rod to 1.0 - rest_strain_reshaped = xi_eq.reshape((-1, 3)) - rest_strain_reshaped = rest_strain_reshaped.at[:, -1].set(1.0) - xi_eq = rest_strain_reshaped.flatten() - if not isinstance(xi_eq, jnp.ndarray): - if isinstance(xi_eq, list): - xi_eq = jnp.array(xi_eq) - else: - raise TypeError( - f"xi_eq must be a jnp.ndarray, but got {type(xi_eq).__name__}" - ) - xi_eq = xi_eq.flatten() - if xi_eq.shape[0] != n_xi: - raise ValueError( - f"xi_eq must have the same shape as the strain vector, but got {xi_eq.shape[0]} instead of {n_xi}" - ) - if not jnp.issubdtype(xi_eq.dtype, jnp.floating): - if not jnp.issubdtype(xi_eq.dtype, jnp.integer): - raise TypeError( - f"xi_eq must be a floating point array, but got {xi_eq.dtype}" - ) - else: - xi_eq = xi_eq.astype(jnp.float32) - - # Stiffness function - compute_stiffness_matrix_for_all_segments_fn = vmap(compute_planar_stiffness_matrix) - if stiffness_fn is None: - - def stiffness_fn( - params: Dict[str, Array], - B_xi: Array, - formulate_in_strain_space: bool = False, - ) -> Array: - """ - Compute the stiffness matrix of the system. - Args: - params: Dictionary of robot parameters - B_xi: Strain basis matrix - formulate_in_strain_space: - whether to formulate the elastic matrix in the strain space - Returns: - S: elastic matrix of shape (n_q, n_q) if formulate_in_strain_space is False or (n_xi, n_xi) otherwise - """ - # length of the segments - l = params["l"] - # cross-sectional area and second moment of area - A = jnp.pi * params["r"] ** 2 - Ib = A**2 / (4 * jnp.pi) - - # elastic and shear modulus - E, G = params["E"], params["G"] - # stiffness matrix of shape (num_segments, 3, 3) - S_sms = compute_stiffness_matrix_for_all_segments_fn(l, A, Ib, E, G) - # we define the elastic matrix of shape (n_xi, n_xi) as K(xi) = K @ xi where K is equal to - S = blk_diag(S_sms) - - if not formulate_in_strain_space: - S = B_xi.T @ S @ B_xi - - return S - - if not callable(stiffness_fn): - raise TypeError( - f"stiffness_fn must be a callable, but got {type(stiffness_fn).__name__}" - ) - - # Actuation mapping function - if actuation_mapping_fn is None: - - def actuation_mapping_fn( - forward_kinematics_fn: Callable, - jacobian_fn: Callable, - params: Dict[str, Array], - B_xi: Array, - q: Array, - eps: float = global_eps, - ) -> Array: - """ - Returns the actuation matrix that maps the actuation space to the configuration space. - Assumes the fully actuated and identity actuation matrix. - Args: - forward_kinematics_fn: function to compute the forward kinematics - jacobian_fn: function to compute the Jacobian - params: dictionary with robot parameters - B_xi: strain basis matrix - q: configuration of the robot - eps: small number to avoid singularities (default: global_eps = 1e-8) - Returns: - A: actuation matrix of shape (n_xi, n_xi) where n_xi is the number of strains. - """ - A = B_xi.T @ jnp.identity(n_xi) @ B_xi - - return A - - if not callable(actuation_mapping_fn): - raise TypeError( - f"actuation_mapping_fn must be a callable, but got {type(actuation_mapping_fn).__name__}" - ) - - if integration_type == "gauss-legendre": - if param_integration is None: - param_integration = 5 - elif integration_type == "gauss-kronrad": - if param_integration is None: - param_integration = 15 - if param_integration not in [15, 21, 31, 41, 51, 61]: - raise ValueError( - f"param_integration must be one of [15, 21, 31, 41, 51, 61] for gauss-kronrad integration, but got {param_integration}" - ) - elif integration_type == "trapezoid": - if param_integration is None: - param_integration = 1000 - else: - raise ValueError( - f"integration_type must be either 'gauss-legendre', 'gauss-kronrad' or 'trapezoid', but got {integration_type}" - ) - - if jacobian_type not in ["explicit", "autodiff"]: - raise ValueError( - f"jacobian_type must be either 'explicit' or 'autodiff', but got {jacobian_type}" - ) - - # ======================================================================================================================= - # Define the functions - # ==================================================== - - # Compute the strain basis matrix - B_xi = compute_strain_basis(strain_selector) - - def apply_eps_to_bend_strains(xi: Array, eps: float) -> Array: - """ - Add a small number to the bending strain to avoid singularities. - - Args: - xi (Array): strain vector of the robot. - eps (float): small number to add to the bending strain. - - Returns: - Array: strain vector with the bending strain modified. - """ - if eps == None: - return xi - else: - xi_reshaped = xi.reshape((-1, 3)) - - xi_bend_sign = jnp.sign(xi_reshaped[:, 0]) - - # set zero sign to 1 (i.e. positive) - xi_bend_sign = jnp.where(xi_bend_sign == 0, 1, xi_bend_sign) - - # add eps to the bending strain (i.e. the first column) - sigma_b_epsed = lax.select( - jnp.abs(xi_reshaped[:, 0]) < eps, - xi_bend_sign * eps, - xi_reshaped[:, 0], - ) - xi_epsed = jnp.stack( - [ - sigma_b_epsed, - xi_reshaped[:, 1], - xi_reshaped[:, 2], - ], - axis=1, - ) - - # Flatten the array - xi_epsed = xi_epsed.flatten() - - return xi_epsed - - def classify_segment( - params: Dict[str, Array], s: Array - ) -> Tuple[Array, Array, Array]: - """ - Classify the point along the robot to the corresponding segment. - - Args: - params (Dict[str, Array]): dictionary of robot parameters - s (Array): point coordinate along the robot in the interval [0, L]. - - Returns: - segment_idx (Array): index of the segment where the point is located - s_segment (Array): point coordinate along the segment in the interval [0, l_segment] - l_cum (Array): cumulative length of the segments starting with 0 - """ - l = params["l"] - - # Compute the cumulative length of the segments starting with 0 - l_cum = jnp.cumsum(jnp.concatenate([jnp.zeros(1), l])) - - # Classify the point along the robot to the corresponding segment - segment_idx = jnp.clip(jnp.sum(s > l_cum) - 1, 0, len(l) - 1) - - # Compute the point coordinate along the segment in the interval [0, l_segment] - s_segment = s - l_cum[segment_idx] - - return segment_idx, s_segment.squeeze(), l_cum - - def chi_fn( - params: Dict[str, Array], xi: Array, s: Array, eps: float = global_eps - ) -> Array: - """ - Compute the pose of the robot at a given point along its length with respect to the strain vector. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L]. - eps (float, optional): small number to avoid singularities. Defaults to global_eps = 1e-8. - - Returns: - chi (Array): pose of the robot at the point s in the interval [0, L] - The pose is represented as a vector [x, y, theta] where (x, y) is the position - and theta is the orientation angle. - """ - th0 = params["th0"] # initial angle of the robot - l = params["l"] # length of each segment [m] - - # Classify the point along the robot to the corresponding segment - segment_idx, s_local, _ = classify_segment(params, s) - - chi_0 = jnp.concatenate([jnp.zeros(2), th0[None]]) # Initial pose of the robot - - # Iteration function - def chi_i(i: int, chi_prev: Array) -> Array: - th_prev = chi_prev[2] # Extract previous orientation angle from val - p_prev = chi_prev[:2] # Extract previous position from val - - # Extract strains for the current segment - kappa = xi[3 * i + 0] # Bending strain - sigma_x = xi[3 * i + 1] # Shear strain - sigma_y = xi[3 * i + 2] # Axial strain - - # Compute the length of the current segment to integrate - l_i = jnp.where(i == segment_idx, s_local, l[i]) - - # Compute the orientation angle for the current segment - dth = kappa * l_i # Angle increment for the current segment - th = th_prev + dth - - # Compute the integrals for the transformation matrix - int_cos_th = jnp.where( - jnp.abs(kappa) < eps, - l_i * jnp.cos(th_prev), - (jnp.sin(th) - jnp.sin(th_prev)) / kappa, - ) - - int_sin_th = jnp.where( - jnp.abs(kappa) < eps, - l_i * jnp.sin(th_prev), - (jnp.cos(th_prev) - jnp.cos(th)) / kappa, - ) - - # Transformation matrix - R = jnp.stack( - [ - jnp.stack([int_cos_th, -int_sin_th]), - jnp.stack([int_sin_th, int_cos_th]), - ] - ) - - # Compute the position - p = p_prev + R @ jnp.stack([sigma_x, sigma_y], axis=-1) - - return jnp.concatenate([p, th[None]]) - - _, chi_list = lax.scan( - f=lambda carry, i: (chi_i(i, carry), chi_i(i, carry)), - init=chi_0, - xs=jnp.arange(num_segments + 1), - ) - - chi = chi_list[segment_idx] - - return chi - - def forward_kinematics_fn( - params: Dict[str, Array], q: Array, s: Array, eps: float = global_eps - ) -> Array: - """ - Compute the forward kinematics of the robot. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - q (Array): configuration vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L]. - eps (float, optional): small number to avoid singularities. Defaults to global_eps = 1e-8. - - Returns: - chi (Array): pose of the robot at the point s in the interval [0, L]. - The pose is represented as a vector [x, y, theta] where (x, y) is the position - and theta is the orientation angle. - """ - # Map the configuration to the strains - xi = xi_eq + B_xi @ q - - chi = chi_fn(params, xi, s, eps) - - return chi - - def J_autodiff( - params: Dict[str, Array], xi: Array, s: Array, eps: float = global_eps - ) -> Array: - """ - Compute the inertial-frame jacobian of the forward kinematics function with respect to the strain vector - using autodiff. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L]. - eps (float, optional): small number to avoid singularities. Defaults to global_eps = 1e-8. - - Returns: - J (Array): inertial-frame jacobian of the forward kinematics function with respect to the strain vector. - """ - # Compute the Jacobian of chi_fn with respect to xi - J = jacobian(lambda _xi: chi_fn(params, _xi, s, eps))(xi) - - # apply the strain basis to the Jacobian - J = J @ B_xi - - return J - - def J_explicit_local( - params: Dict[str, Array], xi: Array, s: Array, eps: float = global_eps - ) -> Array: - """ - Compute the body-frame jacobian of the forward kinematics function with respect to the strain vector - at a given point s using explicit expression in SE(2). - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L_tot]. - - Returns: - J_local (Array) : body-frame jacobian of the forward kinematics function with respect to the strain vector - using explicit expression in SE(2). - """ - - # Classify the point along the robot to the corresponding segment - segment_idx, _, l_cum = classify_segment(params, s) - - xi = apply_eps_to_bend_strains(xi, eps) # Apply eps to the bending strain - - # Initial condition - xi_SE2_0 = xi[0:3] - - Ad_g0_inv_L0 = Adjoint_gn_SE2_inv(xi_SE2_0, l_cum[0], l_cum[1]) - Ad_g0_inv_s = Adjoint_gn_SE2_inv(xi_SE2_0, l_cum[0], s) - - T_g0_L0 = Tangent_gn_SE2(xi_SE2_0, l_cum[0], l_cum[1]) - T_g0_s = Tangent_gn_SE2(xi_SE2_0, l_cum[0], s) - - mat_0_L0 = Ad_g0_inv_L0 @ T_g0_L0 - mat_0_s = Ad_g0_inv_s @ T_g0_s - - J_0_L0 = jnp.concatenate( - [mat_0_L0[None, :, :], jnp.zeros((num_segments - 1, 3, 3))], axis=0 - ) - J_0_s = jnp.concatenate( - [mat_0_s[None, :, :], jnp.zeros((num_segments - 1, 3, 3))], axis=0 - ) - - tuple_J_0 = (J_0_L0, J_0_s) - - # Iteration function - def J_i(tuple_J_prev: Array, i: int) -> Tuple[Tuple[Array, Array], Array]: - J_prev_Lprev, _ = tuple_J_prev - - start_index = 3 * i - xi_SE2_i = lax.dynamic_slice(xi, (start_index,), (3,)) - - Ad_gi_inv_Li = Adjoint_gn_SE2_inv(xi_SE2_i, l_cum[i], l_cum[i + 1]) - Ad_gi_inv_s = Adjoint_gn_SE2_inv(xi_SE2_i, l_cum[i], s) - - T_gi_Li = Tangent_gn_SE2(xi_SE2_i, l_cum[i], l_cum[i + 1]) - T_gi_s = Tangent_gn_SE2(xi_SE2_i, l_cum[i], s) - - mat_i_Li = Ad_gi_inv_Li @ T_gi_Li - mat_i_s = Ad_gi_inv_s @ T_gi_s - - J_new_s = lax.dynamic_update_slice( - jnp.einsum("ij, njk->nik", Ad_gi_inv_s, J_prev_Lprev), - mat_i_s[jnp.newaxis, ...], - (i, 0, 0), - ) - J_new_Li = lax.dynamic_update_slice( - jnp.einsum("ij, njk->nik", Ad_gi_inv_Li, J_prev_Lprev), - mat_i_Li[jnp.newaxis, ...], - (i, 0, 0), - ) - - tuple_J_new = (J_new_Li, J_new_s) - - return tuple_J_new, J_new_s # We accumulate J_new_s - - _, J_array = lax.scan(f=J_i, init=tuple_J_0, xs=jnp.arange(1, num_segments)) - - # Add the initial condition to the Jacobian array - J_array = jnp.concatenate([J_0_s[jnp.newaxis, ...], J_array], axis=0) - - # Extract the Jacobian for the segment that contains the point s - J_segment_SE2_local = lax.dynamic_index_in_dim( - J_array, segment_idx, axis=0, keepdims=False - ) - - # Reorder the lines to match the forward kinematics function - J_local = J_segment_SE2_local[ - :, REORDERED_LINES_FWD_KINE, : - ] # shape: (n_segments, 3, 3) - - J_local = blk_concat(J_local) # shape: (n_segments*3, 3) - J_local = J_local @ B_xi - - return J_local - - def J_explicit_global( - params: Dict[str, Array], xi: Array, s: Array, eps: float = global_eps - ) -> Array: - """ - Compute the inertial-frame jacobian of the forward kinematics function with respect to the strain vector - at a given point s using explicit expression in SE(2). - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L_tot]. - eps (float, optional): small number to avoid singularities. Defaults to global_eps = 1e-8. - - Returns: - J_global (Array): inertial-frame jacobian of the forward kinematics function with respect to the strain vector - using explicit expression in SE(2). - """ - - # Classify the point along the robot to the corresponding segment - segment_idx, _, l_cum = classify_segment(params, s) - - xi = apply_eps_to_bend_strains(xi, eps) # Apply eps to the bending strain - - # Initial condition - xi_SE2_0 = xi[0:3] - - Ad_g0_inv_L0 = Adjoint_gn_SE2_inv(xi_SE2_0, l_cum[0], l_cum[1]) - Ad_g0_inv_s = Adjoint_gn_SE2_inv(xi_SE2_0, l_cum[0], s) - - T_g0_L0 = Tangent_gn_SE2(xi_SE2_0, l_cum[0], l_cum[1]) - T_g0_s = Tangent_gn_SE2(xi_SE2_0, l_cum[0], s) - - mat_0_L0 = Ad_g0_inv_L0 @ T_g0_L0 - mat_0_s = Ad_g0_inv_s @ T_g0_s - - J_0_L0 = jnp.concatenate( - [mat_0_L0[None, :, :], jnp.zeros((num_segments - 1, 3, 3))], axis=0 - ) - J_0_s = jnp.concatenate( - [mat_0_s[None, :, :], jnp.zeros((num_segments - 1, 3, 3))], axis=0 - ) - - tuple_J_0 = (J_0_L0, J_0_s) - - # Iteration function - def J_i(tuple_J_prev: Array, i: int) -> Tuple[Tuple[Array, Array], Array]: - J_prev_Lprev, _ = tuple_J_prev - - start_index = 3 * i - xi_SE2_i = lax.dynamic_slice(xi, (start_index,), (3,)) - - Ad_gi_inv_Li = Adjoint_gn_SE2_inv(xi_SE2_i, l_cum[i], l_cum[i + 1]) - Ad_gi_inv_s = Adjoint_gn_SE2_inv(xi_SE2_i, l_cum[i], s) - - T_gi_Li = Tangent_gn_SE2(xi_SE2_i, l_cum[i], l_cum[i + 1]) - T_gi_s = Tangent_gn_SE2(xi_SE2_i, l_cum[i], s) - - mat_i_Li = Ad_gi_inv_Li @ T_gi_Li - mat_i_s = Ad_gi_inv_s @ T_gi_s - - J_new_s = lax.dynamic_update_slice( - jnp.einsum("ij, njk->nik", Ad_gi_inv_s, J_prev_Lprev), - mat_i_s[jnp.newaxis, ...], - (i, 0, 0), - ) - J_new_Li = lax.dynamic_update_slice( - jnp.einsum("ij, njk->nik", Ad_gi_inv_Li, J_prev_Lprev), - mat_i_Li[jnp.newaxis, ...], - (i, 0, 0), - ) - - tuple_J_new = (J_new_Li, J_new_s) - - return tuple_J_new, J_new_s # We accumulate J_new_s - - _, J_array = lax.scan(f=J_i, init=tuple_J_0, xs=jnp.arange(1, num_segments)) - - # Add the initial condition to the Jacobian array - J_array = jnp.concatenate([J_0_s[jnp.newaxis, ...], J_array], axis=0) - - # Extract the Jacobian for the segment that contains the point s - J_segment_SE2_local = lax.dynamic_index_in_dim( - J_array, segment_idx, axis=0, keepdims=False - ) - - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # From local to global frame : applying the rotation of the pose at point s - - # Get the pose at point s - _, _, theta = chi_fn(params, xi, s, eps) - # Convert the pose to SE(3) representation - c, s = jnp.cos(theta), jnp.sin(theta) - R = jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])]) - g_s = jnp.block([[R, jnp.zeros((2, 1))], [jnp.zeros((1, 2)), jnp.eye(1)]]) - Adjoint_g_s = Adjoint_g_SE2(g_s) - # For each segment, compute the Jacobian in SE(3) coordinates in global frame J_i_global = Adjoint_g_s @ J_i_local - J_segment_SE2_global = jnp.einsum( - "ij,njk->nik", Adjoint_g_s, J_segment_SE2_local - ) # shape: (n_segments, 6, 6) - - # Reorder the lines to match the forward kinematics function - J_global = J_segment_SE2_global[ - :, REORDERED_LINES_FWD_KINE, : - ] # shape: (n_segments, 3, 3) - - J_global = blk_concat(J_global) # shape: (n_segments*3, 3) - J_global = J_global @ B_xi - - return J_global - - def J_Jd_autodiff( - params: Dict[str, Array], - xi: Array, - xi_d: Array, - s: Array, - eps: float = global_eps, - ) -> Tuple[Array, Array]: - """ - Compute the inertial-frame jacobian of the forward kinematics function and its time-derivative using autodiff. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - xi_d (Array): velocity vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L]. - - Returns: - J (Array): inertial-frame jacobian of the forward kinematics function. - J_d (Array): inertial-frame time-derivative of the Jacobian of the forward kinematics function. - """ - - xi = apply_eps_to_bend_strains(xi, eps) # Apply eps to the bending strain - - # Compute the Jacobian of chi_fn with respect to xi - J = jacobian(lambda _xi: chi_fn(params, _xi, s, eps))(xi) - - dJ_dxi = jacobian(J)(xi) - J_d = jnp.tensordot(dJ_dxi, xi_d, axes=([2], [0])) - - # apply the strain basis to the Jacobian - J = J @ B_xi - - # apply the strain basis to the time-derivative of the Jacobian - J_d = J_d @ B_xi - - return J, J_d - - def J_Jd_explicit_local( - params: Dict[str, Array], - xi: Array, - xi_d: Array, - s: Array, - eps: float = global_eps, - ) -> Tuple[Array, Array]: - """ - Compute the body-frame jacobian and its derivative with respect to the strain vector - at a given point s using explicit expression in SE(2). - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - xi_d (Array): velocity vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L]. - eps (float, optional): small number to avoid singularities. Defaults to global_eps = 1e-8. - - Returns: - J_local (Array): body-frame jacobian of the forward kinematics function. - J_d_local (Array): body-frame time-derivative of the jacobian of the forward kinematics function. - """ - # Classify the point along the robot to the corresponding segment - segment_idx, _, l_cum = classify_segment(params, s) - - xi = apply_eps_to_bend_strains(xi, eps) # Apply eps to the bending strain - - # ================================= - # Computation of the Jacobian - - # Initial condition - xi_SE2_0 = xi[0:3] - - Ad_g0_inv_L0 = Adjoint_gn_SE2_inv(xi_SE2_0, l_cum[0], l_cum[1]) - Ad_g0_inv_s = Adjoint_gn_SE2_inv(xi_SE2_0, l_cum[0], s) - - T_g0_L0 = Tangent_gn_SE2(xi_SE2_0, l_cum[0], l_cum[1]) - T_g0_s = Tangent_gn_SE2(xi_SE2_0, l_cum[0], s) - - mat_0_L0 = Ad_g0_inv_L0 @ T_g0_L0 - mat_0_s = Ad_g0_inv_s @ T_g0_s - - J_0_L0 = jnp.concatenate( - [mat_0_L0[None, :, :], jnp.zeros((num_segments - 1, 3, 3))], axis=0 - ) - J_0_s = jnp.concatenate( - [mat_0_s[None, :, :], jnp.zeros((num_segments - 1, 3, 3))], axis=0 - ) - - tuple_J_0 = (J_0_L0, J_0_s) - - # Iteration function - def J_i(tuple_J_prev: Array, i: int) -> Tuple[Tuple[Array, Array], Array]: - J_prev_Lprev, _ = tuple_J_prev - - start_index = 3 * i - xi_SE2_i = lax.dynamic_slice(xi, (start_index,), (3,)) - - Ad_gi_inv_Li = Adjoint_gn_SE2_inv(xi_SE2_i, l_cum[i], l_cum[i + 1]) - Ad_gi_inv_s = Adjoint_gn_SE2_inv(xi_SE2_i, l_cum[i], s) - - T_gi_Li = Tangent_gn_SE2(xi_SE2_i, l_cum[i], l_cum[i + 1]) - T_gi_s = Tangent_gn_SE2(xi_SE2_i, l_cum[i], s) - - mat_i_Li = Ad_gi_inv_Li @ T_gi_Li - mat_i_s = Ad_gi_inv_s @ T_gi_s - - J_new_s = lax.dynamic_update_slice( - jnp.einsum("ij, njk->nik", Ad_gi_inv_s, J_prev_Lprev), - mat_i_s[jnp.newaxis, ...], - (i, 0, 0), - ) - J_new_Li = lax.dynamic_update_slice( - jnp.einsum("ij, njk->nik", Ad_gi_inv_Li, J_prev_Lprev), - mat_i_Li[jnp.newaxis, ...], - (i, 0, 0), - ) - - tuple_J_new = (J_new_Li, J_new_s) - - return tuple_J_new, J_new_s # We accumulate J_new_s - - _, J_array = lax.scan(f=J_i, init=tuple_J_0, xs=jnp.arange(1, num_segments)) - - # Add the initial condition to the Jacobian array - J_array = jnp.concatenate([J_0_s[jnp.newaxis, ...], J_array], axis=0) - - # Extract the Jacobian for the segment that contains the point s - J_segment_SE2_local = lax.dynamic_index_in_dim( - J_array, segment_idx, axis=0, keepdims=False - ) - - # ================================= - # Computation of the time-derivative of the Jacobian - - idx_range = jnp.arange(num_segments) - xi_d_SE2_i = vmap(lambda i: lax.dynamic_slice(xi_d, (3 * i,), (3,)))( - idx_range - ) # shape: (num_segments, 3) - S_i = vmap( - lambda i: lax.dynamic_index_in_dim( - J_segment_SE2_local, i, axis=0, keepdims=False - ) - )(idx_range) # shape: (num_segments, 3, 3) - sum_Sj_xi_d_j = compute_weighted_sums( - J_segment_SE2_local, xi_d_SE2_i, num_segments - ) # shape: (num_segments, 3) - adjoint_sum = vmap(adjoint_SE2)(sum_Sj_xi_d_j) # shape: (num_segments, 3, 3) - - # Compute the time-derivative of the Jacobian - J_d_segment_SE2_local = jnp.einsum( - "ijk, ikl->ijl", adjoint_sum, S_i - ) # shape: (num_segments, 3, 3) - - # Replace the elements of J_d_segment_SE2 for i > segment_idx by null matrices - J_d_segment_SE2_local = jnp.where( - jnp.arange(num_segments)[:, None, None] > segment_idx, - jnp.zeros_like(J_d_segment_SE2_local), - J_d_segment_SE2_local, - ) - - # Reorder the lines to match the forward kinematics function - J_local = J_segment_SE2_local[ - :, REORDERED_LINES_FWD_KINE, : - ] # shape: (n_segments, 3, 3) - J_d_local = J_d_segment_SE2_local[ - :, REORDERED_LINES_FWD_KINE, : - ] # shape: (n_segments, 3, 3) - - # Concatenate the Jacobians and their time-derivatives for all segments - J_local = blk_concat(J_local) # shape: (n_segments*3, 3) - J_d_local = blk_concat(J_d_local) # shape: (n_segments*3, 3) - - # Apply the strain basis to the Jacobian and its time-derivative - J_local = J_local @ B_xi - J_d_local = J_d_local @ B_xi - - return J_local, J_d_local - - def J_Jd_explicit_global( - params: Dict[str, Array], - xi: Array, - xi_d: Array, - s: Array, - eps: float = global_eps, - ) -> Tuple[Array, Array]: - """ - Compute the inertial-frame jacobian and its derivative with respect to the strain vector - at a given point s using explicit expression in SE(2). - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - xi_d (Array): velocity vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L]. - eps (float, optional): small number to avoid singularities. Defaults to global_eps = 1e-8. - - Returns: - J_global (Array): inertial-frame jacobian of the forward kinematics function. - J_d_global (Array): inertial-frame time-derivative of the jacobian of the forward kinematics function. - """ - # Classify the point along the robot to the corresponding segment - segment_idx, _, l_cum = classify_segment(params, s) - - xi = apply_eps_to_bend_strains(xi, eps) # Apply eps to the bending strain - - # ================================= - # Computation of the Jacobian - - # Initial condition - xi_SE2_0 = xi[0:3] - - Ad_g0_inv_L0 = Adjoint_gn_SE2_inv(xi_SE2_0, l_cum[0], l_cum[1]) - Ad_g0_inv_s = Adjoint_gn_SE2_inv(xi_SE2_0, l_cum[0], s) - - T_g0_L0 = Tangent_gn_SE2(xi_SE2_0, l_cum[0], l_cum[1]) - T_g0_s = Tangent_gn_SE2(xi_SE2_0, l_cum[0], s) - - mat_0_L0 = Ad_g0_inv_L0 @ T_g0_L0 - mat_0_s = Ad_g0_inv_s @ T_g0_s - - J_0_L0 = jnp.concatenate( - [mat_0_L0[None, :, :], jnp.zeros((num_segments - 1, 3, 3))], axis=0 - ) - J_0_s = jnp.concatenate( - [mat_0_s[None, :, :], jnp.zeros((num_segments - 1, 3, 3))], axis=0 - ) - - tuple_J_0 = (J_0_L0, J_0_s) - - # Iteration function - def J_i(tuple_J_prev: Array, i: int) -> Tuple[Tuple[Array, Array], Array]: - J_prev_Lprev, _ = tuple_J_prev - - start_index = 3 * i - xi_SE2_i = lax.dynamic_slice(xi, (start_index,), (3,)) - - Ad_gi_inv_Li = Adjoint_gn_SE2_inv(xi_SE2_i, l_cum[i], l_cum[i + 1]) - Ad_gi_inv_s = Adjoint_gn_SE2_inv(xi_SE2_i, l_cum[i], s) - - T_gi_Li = Tangent_gn_SE2(xi_SE2_i, l_cum[i], l_cum[i + 1]) - T_gi_s = Tangent_gn_SE2(xi_SE2_i, l_cum[i], s) - - mat_i_Li = Ad_gi_inv_Li @ T_gi_Li - mat_i_s = Ad_gi_inv_s @ T_gi_s - - J_new_s = lax.dynamic_update_slice( - jnp.einsum("ij, njk->nik", Ad_gi_inv_s, J_prev_Lprev), - mat_i_s[jnp.newaxis, ...], - (i, 0, 0), - ) - J_new_Li = lax.dynamic_update_slice( - jnp.einsum("ij, njk->nik", Ad_gi_inv_Li, J_prev_Lprev), - mat_i_Li[jnp.newaxis, ...], - (i, 0, 0), - ) - - tuple_J_new = (J_new_Li, J_new_s) - - return tuple_J_new, J_new_s # We accumulate J_new_s - - _, J_array = lax.scan(f=J_i, init=tuple_J_0, xs=jnp.arange(1, num_segments)) - - # Add the initial condition to the Jacobian array - J_array = jnp.concatenate([J_0_s[jnp.newaxis, ...], J_array], axis=0) - - # Extract the Jacobian for the segment that contains the point s - J_segment_SE2_local = lax.dynamic_index_in_dim( - J_array, segment_idx, axis=0, keepdims=False - ) - - # ================================= - # Computation of the time-derivative of the Jacobian - - idx_range = jnp.arange(num_segments) - xi_d_SE2_i = vmap(lambda i: lax.dynamic_slice(xi_d, (3 * i,), (3,)))( - idx_range - ) # shape: (num_segments, 3) - S_i = vmap( - lambda i: lax.dynamic_index_in_dim( - J_segment_SE2_local, i, axis=0, keepdims=False - ) - )(idx_range) # shape: (num_segments, 3, 3) - sum_Sj_xi_d_j = compute_weighted_sums( - J_segment_SE2_local, xi_d_SE2_i, num_segments - ) # shape: (num_segments, 3) - adjoint_sum = vmap(adjoint_SE2)(sum_Sj_xi_d_j) # shape: (num_segments, 3, 3) - - # Compute the time-derivative of the Jacobian - J_d_segment_SE2_local = jnp.einsum( - "ijk, ikl->ijl", adjoint_sum, S_i - ) # shape: (num_segments, 3, 3) - - # Replace the elements of J_d_segment_SE2 for i > segment_idx by null matrices - J_d_segment_SE2_local = jnp.where( - jnp.arange(num_segments)[:, None, None] > segment_idx, - jnp.zeros_like(J_d_segment_SE2_local), - J_d_segment_SE2_local, - ) - - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # From local to global frame : applying the rotation of the pose at point s - - # Get the pose at point s - _, _, theta = chi_fn(params, xi, s, eps) - # Convert the pose to SE(3) representation - c, s = jnp.cos(theta), jnp.sin(theta) - R = jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])]) - g_s = jnp.block([[R, jnp.zeros((2, 1))], [jnp.zeros((1, 2)), jnp.eye(1)]]) - Adjoint_g_s = Adjoint_g_SE2(g_s) - # For each segment, compute the Jacobian in SE(3) coordinates in global frame J_i_global = Adjoint_g_s @ J_i_local - J_segment_SE2_global = jnp.einsum( - "ij,njk->nik", Adjoint_g_s, J_segment_SE2_local - ) # shape: (n_segments, 6, 6) - J_d_segment_global = jnp.einsum( - "ij,njk->nik", Adjoint_g_s, J_d_segment_SE2_local - ) # shape: (n_segments, 6, 6) - - # Reorder the lines to match the forward kinematics function - J_global = J_segment_SE2_global[ - :, REORDERED_LINES_FWD_KINE, : - ] # shape: (n_segments, 3, 3) - J_d_global = J_d_segment_global[ - :, REORDERED_LINES_FWD_KINE, : - ] # shape: (n_segments, 3, 3) - - # Concatenate the Jacobians and their time-derivatives for all segments - J_global = blk_concat(J_global) # shape: (n_segments*3, 3) - J_d_global = blk_concat(J_d_global) # shape: (n_segments*3, 3) - - # Apply the strain basis to the Jacobian and its time-derivative - J_global = J_global @ B_xi - J_d_global = J_d_global @ B_xi - - return J_global, J_d_global - - if jacobian_type == "explicit": - jacobian_fn_xi = J_explicit_global - J_Jd = J_Jd_explicit_global - elif jacobian_type == "autodiff": - jacobian_fn_xi = J_autodiff - J_Jd = J_Jd_autodiff - - def jacobian_fn( - params: Dict[str, Array], q: Array, s: Array, eps: float = global_eps - ) -> Array: - """ - Compute the inertial-frame jacobian of the forward kinematics function with respect to the configuration vector q. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - q (Array): configuration vector of the robot. - s (Array): point coordinate along the robot in the interval [0, L]. - eps (float, optional): small number to avoid singularities. Defaults to global_eps = 1e-8. - - Returns: - J (Array): inertial-frame jacobian of the forward kinematics function with respect to the strain vector. - """ - # Map the configuration to the strains - xi = xi_eq + B_xi @ q - - # Add a small number to the bending strain to avoid singularities - xi = apply_eps_to_bend_strains(xi, eps) - - # Compute the Jacobian of chi_fn with respect to xi - J = jacobian_fn_xi(params, xi, s) - return J - - def B_autodiff_fn(params: Dict[str, Array], xi: Array) -> Array: - """ - Compute the mass / inertia matrix of the robot using autodiff. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - - Returns: - B (Array): mass / inertia matrix of the robot. - """ - # Extract the parameters - rho = params["rho"] # density of each segment [kg/m^3] - l = params["l"] # length of each segment [m] - r = params["r"] # radius of each segment [m] - - # Usefull derived quantities - A = jnp.pi * r**2 # cross-sectional area of each segment [m^2] - Ib = A**2 / (4 * jnp.pi) # second moment of area of each segment [m^4] - - l_cum = jnp.cumsum(jnp.concatenate([jnp.zeros(1), l])) - - # Compute each integral - def compute_integral(i): - if integration_type == "gauss-legendre": - Xs, Ws, nGauss = gauss_quadrature( - N_GQ=param_integration, a=l_cum[i], b=l_cum[i + 1] - ) - - J_all = vmap(lambda s: J_autodiff(params, xi, s))(Xs) - Jp_all = J_all[:, :2, :] - Jo_all = J_all[:, 2:, :] - - integrand_JpT_Jp = jnp.einsum("nij,nik->njk", Jp_all, Jp_all) - integrand_JoT_Jo = jnp.einsum("nij,nik->njk", Jo_all, Jo_all) - - integral_Jp = jnp.sum(Ws[:, None, None] * integrand_JpT_Jp, axis=0) - integral_Jo = jnp.sum(Ws[:, None, None] * integrand_JoT_Jo, axis=0) - - integral = rho[i] * A[i] * integral_Jp + rho[i] * Ib[i] * integral_Jo - - elif integration_type == "gauss-kronrad": - rule = GaussKronrodRule(order=param_integration) - - def integrand(s): - J = J_autodiff(params, xi, s) - Jp = J[:2, :] - Jo = J[2:, :] - return rho[i] * A[i] * Jp.T @ Jp + rho[i] * Ib[i] * Jo.T @ Jo - - integral, _, _, _ = rule.integrate( - integrand, l_cum[i], l_cum[i + 1], args=() - ) - - elif integration_type == "trapezoid": - xs = jnp.linspace(l_cum[i], l_cum[i + 1], param_integration) - - J_all = vmap(lambda s: J_autodiff(params, xi, s))(xs) - Jp_all = J_all[:, :2, :] - Jo_all = J_all[:, 2:, :] - - integrand_JpT_Jp = jnp.einsum("nij,nik->njk", Jp_all, Jp_all) - integrand_JoT_Jo = jnp.einsum("nij,nik->njk", Jo_all, Jo_all) - - integral_Jp = jscipy.integrate.trapezoid(integrand_JpT_Jp, x=xs, axis=0) - integral_Jo = jscipy.integrate.trapezoid(integrand_JoT_Jo, x=xs, axis=0) - - integral = rho[i] * A[i] * integral_Jp + rho[i] * Ib[i] * integral_Jo - - return integral - - # Compute the cumulative integral - indices = jnp.arange(num_segments) - integrals = vmap(compute_integral)(indices) - - B = jnp.sum(integrals, axis=0) - - return B - - def C_autodiff_fn( - params: Dict[str, Array], xi: Array, xi_d: Array - ) -> Tuple[Array, Array]: - """ - Compute the Coriolis / centrifugal matrix of the robot - using autodiff. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - xi_d (Array): velocity vector of the robot. - - Returns: - C (Array): Coriolis / centrifugal matrix of the robot. - """ - - n_xi = 3 * num_segments - - def christoffel_fn(i, j, k): - dB_ij = grad(lambda x: B_autodiff_fn(params, x)[i, j])(xi)[k] - dB_ik = grad(lambda x: B_autodiff_fn(params, x)[i, k])(xi)[j] - dB_jk = grad(lambda x: B_autodiff_fn(params, x)[j, k])(xi)[i] - return 0.5 * (dB_ij + dB_ik - dB_jk) - - # # =========================================== - # # Native for loop version - # C = jnp.zeros((n_xi, n_xi)) # Initialize the Coriolis matrix - # for i in range(n_xi): - # for j in range(n_xi): - # for k in range(n_xi): - # christoffel_symbol = christoffel_fn(i, j, k) - # C = C.at[i,j].add(christoffel_symbol * xi_d[k]) - - # # =========================================== - # # For loop with LAX version - # C = jnp.zeros((n_xi, n_xi)) # Initialize the Coriolis matrix - # def body_i(i, C): - # def body_j(j, C): - # def body_k(k, acc): - # christoffel_symbol = christoffel_fn(i, j, k) - # coeff = christoffel_symbol * xi_d[k] - # return acc + coeff - # C_ij = lax.fori_loop(0, n_xi, body_k, 0.0) - # return C.at[i, j].set(C_ij) - # return lax.fori_loop(0, n_xi, body_j, C) - # C = lax.fori_loop(0, n_xi, body_i, C) - - # =========================================== - # Vmap version - def C_ij(i, j): - cs_k = vmap(lambda k: christoffel_fn(i, j, k))(jnp.arange(n_xi)) - return jnp.dot(cs_k, xi_d) - - C = vmap(lambda i: vmap(lambda j: C_ij(i, j))(jnp.arange(n_xi)))( - jnp.arange(n_xi) - ) - - # # =========================================== - # # LAX map version - # def C_ij(i, j, xi_d, n_xi): - # cs_k = lax.map(lambda k: christoffel_fn(i, j, k), jnp.arange(n_xi)) - # return jnp.dot(cs_k, xi_d) - # C = jnp.stack( - # jnp.stack( - # lax.map( - # lambda i: lax.map( - # lambda j: C_ij(i, j, xi_d, n_xi), - # jnp.arange(n_xi) - # ), - # jnp.arange(n_xi)) - # ) - # ) - - return C - - def B_explicit_fn(params: Dict[str, Array], xi: Array) -> Array: - """ - Compute the mass / inertia matrix of the robot using explicit expression. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - - Returns: - B (Array): mass / inertia matrix of the robot. - """ - # Extract the parameters - rho = params["rho"] # density of each segment [kg/m^3] - l = params["l"] # length of each segment [m] - r = params["r"] # radius of each segment [m] - - # Usefull derived quantities - A = jnp.pi * r**2 # cross-sectional area of each segment [m^2] - Ib = A**2 / (4 * jnp.pi) # second moment of area of each segment [m^4] - - l_cum = jnp.cumsum(jnp.concatenate([jnp.zeros(1), l])) - - # Compute each integral - def compute_integral(i): - if integration_type == "gauss-legendre": - Xs, Ws, nGauss = gauss_quadrature( - N_GQ=param_integration, a=l_cum[i], b=l_cum[i + 1] - ) - - J_all = vmap(lambda s: J_explicit_local(params, xi, s))(Xs) - Jp_all = J_all[:, :2, :] - Jo_all = J_all[:, 2:, :] - - integrand_JpT_Jp = jnp.einsum("nij,nik->njk", Jp_all, Jp_all) - integrand_JoT_Jo = jnp.einsum("nij,nik->njk", Jo_all, Jo_all) - - integral_Jp = jnp.sum(Ws[:, None, None] * integrand_JpT_Jp, axis=0) - integral_Jo = jnp.sum(Ws[:, None, None] * integrand_JoT_Jo, axis=0) - - integral = rho[i] * A[i] * integral_Jp + rho[i] * Ib[i] * integral_Jo - - elif integration_type == "gauss-kronrad": - rule = GaussKronrodRule(order=param_integration) - - def integrand(s): - J = J_explicit_local(params, xi, s) - Jp = J[:2, :] - Jo = J[2:, :] - return rho[i] * A[i] * Jp.T @ Jp + rho[i] * Ib[i] * Jo.T @ Jo - - integral, _, _, _ = rule.integrate( - integrand, l_cum[i], l_cum[i + 1], args=() - ) - - elif integration_type == "trapezoid": - xs = jnp.linspace(l_cum[i], l_cum[i + 1], param_integration) - - J_all = vmap(lambda s: J_explicit_local(params, xi, s))(xs) - Jp_all = J_all[:, :2, :] - Jo_all = J_all[:, 2:, :] - - integrand_JpT_Jp = jnp.einsum("nij,nik->njk", Jp_all, Jp_all) - integrand_JoT_Jo = jnp.einsum("nij,nik->njk", Jo_all, Jo_all) - - integral_Jp = jscipy.integrate.trapezoid(integrand_JpT_Jp, x=xs, axis=0) - integral_Jo = jscipy.integrate.trapezoid(integrand_JoT_Jo, x=xs, axis=0) - - integral = rho[i] * A[i] * integral_Jp + rho[i] * Ib[i] * integral_Jo - - return integral - - # Compute the cumulative integral - indices = jnp.arange(num_segments) - integrals = vmap(compute_integral)(indices) - - B = jnp.sum(integrals, axis=0) - - return B - - def C_explicit_fn( - params: Dict[str, Array], xi: Array, xi_d: Array - ) -> Tuple[Array, Array]: - """ - Compute the Coriolis / centrifugal matrix of the robot using explicit expression. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - xi_d (Array): velocity vector of the robot. - - Returns: - C (Array): Coriolis / centrifugal matrix of the robot. - """ - - # Extract the parameters - rho = params["rho"] # density of each segment [kg/m^3] - l = params["l"] # length of each segment [m] - r = params["r"] # radius of each segment [m] - - # Usefull derived quantities - A = jnp.pi * r**2 # cross-sectional area of each segment [m^2] - Ib = A**2 / (4 * jnp.pi) # second moment of area of each segment [m^4] - - l_cum = jnp.cumsum(jnp.concatenate([jnp.zeros(1), l])) - - # Compute each integral - def compute_integral_C(i): - if integration_type == "gauss-legendre": - Xs, Ws, nGauss = gauss_quadrature( - N_GQ=param_integration, a=l_cum[i], b=l_cum[i + 1] - ) - - J_all, J_d_all = vmap( - lambda s: J_Jd_explicit_local(params, xi, xi_d, s) - )(Xs) # [[Jp1],[Jp2],[Jo]] - J_all = J_all[:, REORDERED_LINES_STRAIN, :] # [[Jo],[Jp1],[Jp2]] - J_d_all = J_d_all[ - :, REORDERED_LINES_STRAIN, : - ] # [[Jo_d],[Jp1_d],[Jp2_d]] - M_a = rho[i] * jnp.diag( - jnp.stack([Ib[i], A[i], A[i]], axis=0) - ) # [[O],[P],[P]] - - integrand_C = vmap( - lambda _J_i, _J_d_i: ( - _J_i.T - @ (adjoint_star_SE2((_J_i @ xi_d)) @ M_a @ _J_i + M_a @ _J_d_i) - ) - )(J_all, J_d_all) - - integral_C = jnp.sum(Ws[:, None, None] * integrand_C, axis=0) - - elif integration_type == "gauss-kronrad": - rule = GaussKronrodRule(order=param_integration) - - def integrand(s): - J, J_d = J_Jd_explicit_local( - params, xi, xi_d, s - ) # [[Jp1],[Jp2],[Jo]] - J = J[REORDERED_LINES_STRAIN, :] # [Jo, Jp1, Jp2] - J_d = J_d[REORDERED_LINES_STRAIN, :] # [Jo_d, Jp1_d, Jp2_d] - M_a = rho[i] * jnp.diag( - jnp.stack([Ib[i], A[i], A[i]], axis=0) - ) # [[O],[P],[P]] - return J.T @ (adjoint_star_SE2((J @ xi_d)) @ M_a @ J + M_a @ J_d) - - integral_C, _, _, _ = rule.integrate( - integrand, l_cum[i], l_cum[i + 1], args=() - ) - - elif integration_type == "trapezoid": - xs = jnp.linspace(l_cum[i], l_cum[i + 1], param_integration) - - J_all, J_d_all = vmap( - lambda s: J_Jd_explicit_local(params, xi, xi_d, s) - )(xs) # [[Jp1],[Jp2],[Jo]] - J_all = J_all[:, REORDERED_LINES_STRAIN, :] # [[Jo],[Jp1],[Jp2]] - J_d_all = J_d_all[ - :, REORDERED_LINES_STRAIN, : - ] # [[Jo_d],[Jp1_d],[Jp2_d]] - M_a = rho[i] * jnp.diag( - jnp.stack([Ib[i], A[i], A[i]], axis=0) - ) # [[O],[P],[P]] - - integrand_C = vmap( - lambda _J_i, _J_d_i: ( - _J_i.T - @ (adjoint_star_SE2((_J_i @ xi_d)) @ M_a @ _J_i + M_a @ _J_d_i) - ) - )(J_all, J_d_all) - - integral_C = jscipy.integrate.trapezoid(integrand_C, x=xs, axis=0) - - return integral_C - - # Compute the cumulative integral - indices = jnp.arange(num_segments) - integrals = vmap(compute_integral_C)(indices) - - C = jnp.sum(integrals, axis=0) - - return C - - def U_g_fn_xi( - params: Dict[str, Array], xi: Array, eps: float = global_eps - ) -> Array: - """ - Compute the gravity vector of the robot. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - - Returns: - U_g (Array): gravity vector of the robot. - """ - - # Extract the parameters - g = params["g"] # gravity vector [m/s^2] - rho = params["rho"] # density of each segment [kg/m^3] - l = params["l"] # length of each segment [m] - r = params["r"] # radius of each segment [m] - - # Usefull derived quantitie - A = jnp.pi * r**2 # cross-sectional area of each segment [m^2] - - l_cum = jnp.cumsum(jnp.concatenate([jnp.zeros(1), l])) - - # Compute each integral - def compute_integral(i): - if integration_type == "gauss-legendre": - Xs, Ws, nGauss = gauss_quadrature( - N_GQ=param_integration, a=l_cum[i], b=l_cum[i + 1] - ) - chi_s = vmap(lambda s: chi_fn(params, xi, s, eps))(Xs) - p_s = chi_s[:, :2] - integrand = -rho[i] * A[i] * jnp.einsum("ij,j->i", p_s, g) - - # Compute the integral - integral = jnp.sum(Ws * integrand) - - elif integration_type == "gauss-kronrad": - rule = GaussKronrodRule(order=param_integration) - - def integrand(s): - chi_s = chi_fn(params, xi, s, eps) - p_s = chi_s[:2] - return -rho[i] * A[i] * jnp.dot(p_s, g) - - # Compute the integral - integral, _, _, _ = rule.integrate( - integrand, l_cum[i], l_cum[i + 1], args=() - ) - - elif integration_type == "trapezoid": - xs = jnp.linspace(l_cum[i], l_cum[i + 1], param_integration) - chi_s = vmap(lambda s: chi_fn(params, xi, s, eps))(xs) - p_s = chi_s[:, :2] - integrand = -rho[i] * A[i] * jnp.einsum("ij,j->i", p_s, g) - - # Compute the integral - integral = jscipy.integrate.trapezoid(integrand, x=xs) - - return integral - - # Compute the cumulative integral - indices = jnp.arange(num_segments) - integrals = vmap(compute_integral)(indices) - - U_g = jnp.sum(integrals) - - return U_g - - def G_autodiff_fn( - params: Dict[str, Array], - xi: Array, - ) -> Array: - """ - Compute the gravity vector of the robot - using autodiff. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - - Returns: - G (Array) : gravity vector of the robot. - """ - - G = jacobian(lambda _xi: U_g_fn_xi(params, _xi))(xi) - - return G - - def G_explicit_fn( - params: Dict[str, Array], - xi: Array, - ) -> Array: - """ - Compute the gravity vector of the robot - using explicit expressions. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - xi (Array): strain vector of the robot. - - Returns: - G (Array): gravity vector of the robot. - """ - # Extract the parameters - g = params["g"] # gravity vector [m/s^2] - rho = params["rho"] # density of each segment [kg/m^3] - l = params["l"] # length of each segment [m] - r = params["r"] # radius of each segment [m] - - # Usefull derived quantitie - A = jnp.pi * r**2 # cross-sectional area of each segment [m^2] - - l_cum = jnp.cumsum(jnp.concatenate([jnp.zeros(1), l])) - - # Compute each integral - def compute_integral(i): - if integration_type == "gauss-legendre": - Xs, Ws, nGauss = gauss_quadrature( - N_GQ=param_integration, a=l_cum[i], b=l_cum[i + 1] - ) - - J_all = vmap(lambda s: J_explicit_global(params, xi, s))(Xs) - Jp_all = J_all[:, :2, :] # shape: (nGauss, n_segments, 3, 3) - - # Compute the integrand - integrand = -rho[i] * A[i] * jnp.einsum("ijk,j->ik", Jp_all, g) - - # Multiply each element of integrand by the corresponding weight - weighted_integrand = jnp.einsum("i, ij->ij", Ws, integrand) - - # Compute the integral - integral = jnp.sum( - weighted_integrand, axis=0 - ) # sum over the Gauss points - - elif integration_type == "gauss-kronrad": - rule = GaussKronrodRule(order=param_integration) - - def integrand(s): - J = J_explicit_global(params, xi, s) - Jp = J[:2, :] - return -rho[i] * A[i] * jnp.dot(g, Jp) - - # Compute the integral - integral, _, _, _ = rule.integrate( - integrand, l_cum[i], l_cum[i + 1], args=() - ) - - elif integration_type == "trapezoid": - xs = jnp.linspace(l_cum[i], l_cum[i + 1], param_integration) - - J_all = vmap(lambda s: J_explicit_global(params, xi, s))(xs) - Jp_all = J_all[:, :2, :] # shape: (nGauss, n_segments, 3, 3) - - # Compute the integrand - integrand = -rho[i] * A[i] * jnp.einsum("ijk,j->ik", Jp_all, g) - - # Compute the integral - integral = jnp.sum(integrand, axis=0) - - return integral - - # Compute the cumulative integral - indices = jnp.arange(num_segments) - integrals = vmap(compute_integral)(indices) - - G = jnp.sum(integrals, axis=0) # sum over the segments - - return G - - if jacobian_type == "explicit": - B_fn_xi = B_explicit_fn - C_fn_xi = C_explicit_fn - G_fn_xi = G_explicit_fn - elif jacobian_type == "autodiff": - B_fn_xi = B_autodiff_fn - C_fn_xi = C_autodiff_fn - G_fn_xi = G_autodiff_fn - - def dynamical_matrices_fn( - params: Dict[str, Array], q: Array, q_d: Array, eps: float = global_eps - ) -> Tuple[Array, Array, Array, Array, Array, Array]: - """ - Compute the dynamical matrices of the robot. - - Args: - params (Dict[str, Array]): dictionary of robot parameters. - q (Array): configuration vector of the robot. - q_d (Array): velocity vector of the robot. - eps (float, optional): small number to avoid singularities. Defaults to global_eps. - - Returns: - B (Array): mass / inertia matrix of the robot. (shape: (n_q, n_q)) - C (Array): Coriolis / centrifugal matrix of the robot. (shape: (n_q, n_q)) - G (Array): gravity vector of the robot. (shape: (n_q,)) - K (Array): elastic vector of the robot. (shape: (n_q,)) - D (Array): dissipative matrix of the robot. (shape: (n_q, n_q)) - alpha (Array): actuation matrix of the robot. (shape: (n_q, n_tau)) - """ - # Map the configuration to the strains - xi = xi_eq + B_xi @ q - xi_d = B_xi @ q_d - - # Add a small number to the bending strain to avoid singularities - xi_epsed = apply_eps_to_bend_strains(xi, eps) - - # Compute the stiffness matrix - K = stiffness_fn(params, B_xi, formulate_in_strain_space=True) - # Apply the strain basis to the stiffness matrix - K = B_xi.T @ K @ (xi - xi_eq) # evaluate K(xi) = K @ xi - - # Compute the actuation matrix - A = actuation_mapping_fn( - forward_kinematics_fn, jacobian_fn, params, B_xi, q, eps - ) - # Apply the strain basis to the actuation matrix - alpha = A - - # Dissipative matrix - D = params.get("D", jnp.zeros((n_xi, n_xi))) - # Apply the strain basis to the dissipative matrix - D = B_xi.T @ D @ B_xi - - # Mass/inertia matrix - B = B_xi.T @ B_fn_xi(params, xi_epsed) @ B_xi - - # Coriolis matrix - C = B_xi.T @ C_fn_xi(params, xi_epsed, xi_d) @ B_xi - - # Gravitational matrix - G = B_xi.T @ G_fn_xi(params, xi_epsed).squeeze() - - return B, C, G, K, D, alpha - - def kinetic_energy_fn( - params: Dict[str, Array], q: Array, q_d: Array, eps: float = global_eps - ) -> Array: - """ - Compute the kinetic energy of the system. - Args: - params (Dict[str, Array]): dictionary of robot parameters. - q (Array): generalized coordinates of shape (n_q, ) - q_d (Array): generalized velocities of shape (n_q, ) - eps (float, optional): small number to avoid singularities (e.g., division by zero). Defaults to global_eps. - Returns: - T (Array): kinetic energy of shape () - """ - - # Map the configuration to the strains - xi = xi_eq + B_xi @ q - # Add a small number to the bending strain to avoid singularities - xi_epsed = apply_eps_to_bend_strains(xi, eps) - - # Compute the inertia matrix - B = B_fn_xi(params, xi_epsed) - - # Kinetic energy - T = (0.5 * q_d.T @ B @ q_d).squeeze() - - return T - - def potential_energy_fn( - params: Dict[str, Array], q: Array, eps: float = global_eps - ) -> Array: - """ - Compute the potential energy of the system. - Args: - params (Dict[str, Array]): dictionary of robot parameters. - q (Array): generalized coordinates of shape (n_q, ) - eps (float, optional): small number to avoid singularities (e.g., division by zero). Defaults to global_eps. - Returns: - U (Array): potential energy of shape () - """ - # Map the configuration to the strains - xi = xi_eq + B_xi @ q - # Add a small number to the bending strain to avoid singularities - xi_epsed = apply_eps_to_bend_strains(xi, eps) - - # compute the stiffness matrix - K = stiffness_fn(params, B_xi, formulate_in_strain_space=True) - # elastic energy - U_K = 0.5 * (xi - xi_eq).T @ K @ (xi - xi_eq) # evaluate K(xi) = K @ xi - - # gravitational potential energy - U_G = U_g_fn_xi(params, xi_epsed) - - # total potential energy - U = (U_G + U_K).squeeze() - - return U - - def energy_fn( - params: Dict[str, Array], q: Array, q_d: Array, eps: float = global_eps - ) -> Array: - """ - Compute the total energy of the system. - Args: - params (Dict[str, Array]): dictionary of robot parameters. - q (Array): generalized coordinates of shape (n_q, ) - q_d (Array): generalized velocities of shape (n_q, ) - eps (float, optional): small number to avoid singularities (e.g., division by zero). Defaults to global_eps. - Returns: - E (Array): total energy of shape () - """ - T = kinetic_energy_fn(params, q, q_d, eps) - U = potential_energy_fn(params, q, eps) - E = T + U - - return E - - def operational_space_dynamical_matrices_fn( - params: Dict[str, Array], - q: Array, - q_d: Array, - s: Array, - B: Array, - C: Array, - operational_space_selector: Tuple = (True, True, True), - eps: float = global_eps, - ) -> Tuple[Array, Array, Array, Array, Array]: - """ - Compute the dynamics in operational space. - The implementation is based on Chapter 7.8 of "Modelling, Planning and Control of Robotics" by - Siciliano, Sciavicco, Villani, Oriolo. - Args: - params (Dict[str, Array]): dictionary of robot parameters. - q (Array): generalized coordinates of shape (n_q, ) - q_d (Array): generalized velocities of shape (n_q, ) - s (Array): point coordinate along the robot in the interval [0, L]. - B (Array): mass / inertia matrix of the robot in the generalized coordinates of shape (n_q, n_q) - C (Array): Coriolis / centrifugal matrix of the robot in the generalized coordinates of shape (n_q, n_q) - operational_space_selector (Tuple, optional): tuple of shape (3,) to select the operational space variables. - For example, (True, True, False) selects only the positional components of the operational space. - eps (float, optional): small number to avoid singularities (e.g., division by zero). Defaults to global_eps. - Returns: - Lambda (Array): inertia matrix in the operational space of shape (3, 3) - mu (Array): coriolis and centrifugal matrix in the operational space of shape (3, ) - J (Array): inertial-frame jacobian of the end-effector pose with respect to the generalized coordinates - J_d (Array): inertial-frame time-derivative of the jacobian of the end-effector pose with respect to the generalized coordinates - of shape (3, n_q) - JB_pinv (Array): Dynamically-consistent pseudo-inverse of the inertial-frame jacobian. Allows the mapping of torques - from the generalized coordinates to the operational space: f = JB_pinv.T @ tau_q - Shape (n_q, 3) - """ - # Map the configuration to the strains - xi = xi_eq + B_xi @ q - xi_d = B_xi @ q_d - - # classify the point along the robot to the corresponding segment - _, s_segment, _ = classify_segment(params, s) - - # make operational_space_selector a boolean array - operational_space_selector = onp.array(operational_space_selector, dtype=bool) - - # Jacobian and its time-derivative - J, J_d = J_Jd(params, xi, xi_d, s_segment, eps) - J = jnp.squeeze(J) - J_d = jnp.squeeze(J_d) - - J = J[operational_space_selector, :] - J_d = J_d[operational_space_selector, :] - - # inverse of the inertia matrix in the configuration space - B_inv = jnp.linalg.inv(B) - - Lambda = jnp.linalg.inv( - J @ B_inv @ J.T - ) # inertia matrix in the operational space - mu = Lambda @ ( - J @ B_inv @ C - J_d - ) # coriolis and centrifugal matrix in the operational space - - JB_pinv = ( - B_inv @ J.T @ Lambda - ) # dynamically-consistent pseudo-inverse of the Jacobian - - return Lambda, mu, J, J_d, JB_pinv - - auxiliary_fns: Dict[str, Callable[..., Any]] = { - "apply_eps_to_bend_strains": apply_eps_to_bend_strains, - "classify_segment": classify_segment, - "stiffness_fn": stiffness_fn, - "actuation_mapping_fn": actuation_mapping_fn, - "jacobian_fn": jacobian_fn, - "kinetic_energy_fn": kinetic_energy_fn, - "potential_energy_fn": potential_energy_fn, - "energy_fn": energy_fn, - "operational_space_dynamical_matrices_fn": operational_space_dynamical_matrices_fn, - } - - return B_xi, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns diff --git a/src/jsrm/systems/utils.py b/src/jsrm/systems/utils.py index cd6d42b..d2911c5 100644 --- a/src/jsrm/systems/utils.py +++ b/src/jsrm/systems/utils.py @@ -93,7 +93,7 @@ def compute_strain_basis( strain_selector: Array, ) -> Array: """ - Compute strain basis based on boolean strain selector. + Compute constant strain basis based on boolean strain selector. Args: strain_selector (Array): boolean array of shape (n_xi, ) specifying which strain components are active @@ -132,6 +132,39 @@ def compute_planar_stiffness_matrix( return S +def compute_spatial_stiffness_matrix( + l: Array, A: Array, Ib: Array, J: Array, E: Array, G: Array +) -> Array: + """ + Compute the stiffness matrix of the system. + Args: + l: length of the segment of shape () + A: cross-sectional area of shape () + Ib: second moment of area of shape () + J: polar moment of inertia of shape () + E: Elastic modulus of shape () + G: Shear modulus of shape () + + Returns: + S: stiffness matrix of shape (3, 3) + """ + S = l * jnp.diag( + jnp.stack( + [ + E * Ib, # bending X + E * Ib, # bending Y + G * J, # torsion Z + 4 / 3 * A * G, # shear X (approx.) + 4 / 3 * A * G, # shear Y (approx.) + A * E, # axial Z + ], + axis=0, + ) + ) + + return S + + def gauss_quadrature(N_GQ: int, a=0.0, b=1.0) -> Tuple[Array, Array, int]: """ Computes the Legendre-Gauss nodes and weights on the interval [0, 1] @@ -202,3 +235,24 @@ def convergence_condition(y): Ws = jnp.concatenate([jnp.array([0.0]), Ws, jnp.array([0.0])]) return Xs, Ws, N_GQ + 2 + + +def scale_gaussian_quadrature( + Xs: Array, Ws: Array, a: float = 0.0, b: float = 1.0 +) -> Tuple[Array, Array]: + """ + Scale the Gauss nodes and weights from [0, 1] to the interval [a, b]. + + Args: + Xs (Array): The Gauss nodes on [0, 1]. + Ws (Array): The Gauss weights on [0, 1]. + a (float): The lower bound of the interval. + b (float): The upper bound of the interval. + + Returns: + Xs_scaled (Array): The scaled Gauss nodes on [a, b]. + Ws_scaled (Array): The scaled Gauss weights on [a, b]. + """ + Xs_scaled = a + (b - a) * Xs + Ws_scaled = Ws * (b - a) + return Xs_scaled, Ws_scaled diff --git a/src/jsrm/utils/lie_algebra.py b/src/jsrm/utils/lie_algebra.py new file mode 100644 index 0000000..3d985ff --- /dev/null +++ b/src/jsrm/utils/lie_algebra.py @@ -0,0 +1,801 @@ +import jax.numpy as jnp +from jax import lax + +# for documentation +from jax import Array + +# ================================================================================================ +# SE(2) operators +# =================================== +J = jnp.array([[0, -1], [1, 0]]) + + +def hat_SE2(vec3: Array) -> Array: + """ + Computes the hat operator for a 3D vector of se(2). + + Args: + vec3 (Array): array-like, shape (3,1) + A 3-dimensional vector representing the screw. + The first element correspond to the angular component, + and the last two elements correspond to the linear components. + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the hat operator of the input screw vector. + """ + vec3 = vec3.reshape(-1) # Ensure vec3 is a 1D array + + ang = vec3[0] # Angular part + lin = vec3[1:].reshape((2, 1)) # Linear as a (2,1) vector + + angtilde = ang * J + + hat = jnp.block([[angtilde, lin], [jnp.zeros((1, 2)), jnp.zeros((1, 1))]]) + + return hat + + +def exp_SE2(vec3: Array) -> Array: + """ + Computes the exponential map for a 3D vector of se(2). + + Args: + vec3 (Array): array-like, shape (3,1) + A 3-dimensional vector representing the position. + [theta, x, y] where theta is the rotation angle and (x, y) is the translation vector. + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the exponential map of the input screw vector. + """ + vec3 = vec3.reshape(-1) # Ensure vec3 is a 1D array + + theta = vec3[0] + p = vec3[1:].reshape((2, 1)) + + cos = jnp.cos(theta) + sin = jnp.sin(theta) + R = jnp.array([[cos, -sin], [sin, cos]]) # Rotation matrix + + g = jnp.block([[R, p], [jnp.zeros((1, 2)), jnp.ones((1, 1))]]) + + return g + + +def adjoint_se2(vec3: Array) -> Array: + """ + Computes the adjoint representation of a vector of se(2). + + Args: + vec3 (Array): array-like, shape (3, 1) + A 3-dimensional vector representing the screw. + The first element correspond to the angular component, + and the last two elements correspond to the linear component. + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the adjoint transformation of the input screw vector. + """ + vec3 = vec3.reshape(-1) # Ensure vec6 is a 1D array + + ang = vec3[0] + lin = vec3[1:].reshape((2, 1)) # Linear as a (3,1) vector + + adj = jnp.concatenate( + [jnp.zeros((1, 3)), jnp.concatenate([-J @ lin, ang * J], axis=1)] + ) + + return adj + + +def coadjoint_se2(vec3: Array) -> Array: + """ + Computes the co-adjoint representation of a vector of se(2). + + Args: + vec3 (Array): array-like, shape (3, 1) + A 3-dimensional vector representing the screw. + The first element correspond to the angular component, + and the last two elements correspond to the linear component. + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the co-adjoint transformation of the input screw vector. + """ + vec3 = vec3.reshape(-1) # Ensure vec6 is a 1D array + + ang = vec3[0] + lin = vec3[1:].reshape((2, 1)) # Linear as a (3,1) vector + + adj_star = jnp.concatenate( + [jnp.zeros((3, 1)), jnp.concatenate([lin.T @ J, ang * J], axis=0)], axis=1 + ) + + return adj_star + + +def Adjoint_g_SE2(mat3: Array) -> Array: + """ + Computes the adjoint representation of a 3x3 matrix. + + Args: + mat4 (Array): array-like, shape (4,4) + A 4x4 matrix representing the transformation. + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the Adjoint transformation of the input matrix. + """ + R = mat3[:2, :2] # Extract the angular part (top-left 2x2 block) + t = mat3[:2, 2].reshape((2, 1)) # Extract the linear part (top-right column) + + Adjoint = jnp.concatenate( + [ + jnp.concatenate([jnp.ones(((1, 1))), jnp.zeros((1, 2))], axis=1), + jnp.concatenate([-J @ t, R], axis=1), + ] + ) + + return Adjoint + + +def Adjoint_g_inv_SE2(mat3: Array) -> Array: + """ + Computes the adjoint representation of a 3x3 matrix. + + Args: + mat4 (Array): array-like, shape (4,4) + A 4x4 matrix representing the transformation. + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the Adjoint transformation of the input matrix. + """ + Adj = Adjoint_g_SE2(mat3) # Adjoint representation of the input matrix + + # Extract R and -Jt from the Adjoint matrix + R = Adj[1:, 1:] + mJt = Adj[1:, 0].reshape(-1, 1) + + # Compute the inverse using the Schur complement + R_inv = jnp.transpose(R) # Since R is a rotation matrix, R^-1=R^T + # Construct the inverse Adjoint matrix + inverse_Adjoint = jnp.concatenate( + [ + jnp.concatenate([jnp.ones(((1, 1))), jnp.zeros((1, 2))], axis=1), + jnp.concatenate([-R_inv @ mJt, R_inv], axis=1), + ] + ) + + return inverse_Adjoint + + +def Adjoint_gi_se2( + xi_i: Array, + s_i: float, + eps: float, +) -> Array: + """ + Computes the adjoint representation of a position of a points at s_i (local curvilinear coordinate) + along a rod in SE(2) deformed ine the current segment according to a strain vector xi_i. + + Args: + xi_i (Array): array-like, shape (3,1) + A 3-dimensional vector representing the screw in the current segment. + s_i (float): + The curvilinear coordinate along the rod, representing the position of a point in the current segment. + eps (float): small value to avoid division by zero + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the adjoint transformation of the input screw vector at the specified position. + """ + # We suppose here that theta is not zero thanks to a previous use of apply_eps + theta = xi_i[0] # Angular part + adjoint_xi_i = adjoint_se2(xi_i) # Adjoint representation of the input vector + + cos = jnp.cos(s_i * theta) + sin = jnp.sin(s_i * theta) + + Adjoint = lax.cond( + jnp.abs(theta) <= eps, + lambda _: jnp.eye(3) + s_i * adjoint_xi_i, # Avoid division by zero + lambda _: ( + jnp.eye(3) + + 1 / (2 * theta) * (3 * sin - s_i * theta * cos) * adjoint_xi_i + + 1 + / (2 * jnp.power(theta, 2)) + * (4 - 4 * cos - s_i * theta * sin) + * jnp.linalg.matrix_power(adjoint_xi_i, 2) + + 1 + / (2 * jnp.power(theta, 3)) + * (sin - s_i * theta * cos) + * jnp.linalg.matrix_power(adjoint_xi_i, 3) + + 1 + / (2 * jnp.power(theta, 4)) + * (2 - 2 * cos - s_i * theta * sin) + * jnp.linalg.matrix_power(adjoint_xi_i, 4) + ), + operand=None, + ) + + return Adjoint + + +def Adjoint_gi_se2_inv( + xi_i: Array, + s_i: float, + eps: float, +) -> Array: + """ + Computes the adjoint representation of a position of a points at s_i (local curvilinear coordinate) + along a rod in SE(2) deformed ine the current segment according to a strain vector xi_i. + + Args: + xi_i (Array): array-like, shape (3,1) + A 3-dimensional vector representing the screw in SE(2). + s_i (float): + The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. + eps (float): small value to avoid division by zero + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the adjoint transformation of the input screw vector at the specified position. + """ + Adj = Adjoint_gi_se2( + xi_i, s_i, eps=eps + ) # Adjoint representation of the input vector + + # Extract R and -Jt from the Adjoint matrix + R = Adj[1:, 1:] + mJt = Adj[1:, 0].reshape(-1, 1) + + # Compute the inverse using the Schur complement + R_inv = jnp.transpose(R) # Since R is a rotation matrix, R^-1=R^T + # Construct the inverse Adjoint matrix + inverse_Adjoint = jnp.concatenate( + [ + jnp.concatenate([jnp.ones(((1, 1))), jnp.zeros((1, 2))], axis=1), + jnp.concatenate([-R_inv @ mJt, R_inv], axis=1), + ] + ) + + return inverse_Adjoint + + +def Tangent_gi_se2( + xi_i: Array, + s_i: float, + eps: float, +) -> Array: + """ + Computes the tangent representation of a position of a points at s_i (local curvilinear coordinate) + along a rod in SE(2) deformed in the current segment according to a strain vector xi_i. + + Args: + xi_i (Array): array-like, shape (3,1) + A 3-dimensional vector representing the screw in SE(2). + s_i (float): + The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. + eps (float): small value to avoid division by zero + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the tangent transformation of the input screw vector at the specified position. + """ + # We suppose here that theta is not zero thanks to a previous use of apply_eps + theta = xi_i[0] # Angular part + adjoint_xi_i = adjoint_se2(xi_i) # Adjoint representation of the input vector + + cos = jnp.cos(s_i * theta) + sin = jnp.sin(s_i * theta) + + Tangent = lax.cond( + jnp.abs(theta) <= eps, + lambda _: s_i * jnp.eye(3) + s_i**2 / 2 * adjoint_xi_i, + lambda _: ( + s_i * jnp.eye(3) + + 1 + / (2 * jnp.power(theta, 2)) + * (4 - 4 * cos - s_i * theta * sin) + * adjoint_xi_i + + 1 + / (2 * jnp.power(theta, 3)) + * (4 * s_i * theta - 5 * sin + s_i * theta * cos) + * jnp.linalg.matrix_power(adjoint_xi_i, 2) + + 1 + / (2 * jnp.power(theta, 4)) + * (2 - 2 * cos - s_i * theta * sin) + * jnp.linalg.matrix_power(adjoint_xi_i, 3) + + 1 + / (2 * jnp.power(theta, 5)) + * (2 * s_i * theta - 3 * sin + s_i * theta * cos) + * jnp.linalg.matrix_power(adjoint_xi_i, 4) + ), + operand=None, + ) + + return Tangent + + +# ================================================================================================ +# SE(3) operators +# =================================== +def tilde_SE3(vec3: Array) -> Array: + """ + Computes the tilde operator of SE(3) for a 3D vector. + + Args: + vec3 (Array): array-like, shape (3,1) + A 3-dimensional vector. + + Returns: + Array: shape (3, 3) + A 3x3 matrix representing the tilde operator of the input vector. + """ + vec3 = vec3.reshape(-1) # Ensure vec3 is a 1D array + + # Extract components of the vector + x, y, z = vec3.flatten() + + # Construct the tilde operator + tilde = jnp.array([[0, -z, y], [z, 0, -x], [-y, x, 0]]) + return tilde + + +def hat_SE3(vec6: Array) -> Array: + """ + Computes the hat operator for a 6D vector of se(3). + + Args: + vec6 (Array): array-like, shape (6,1) + A 6-dimensional vector representing the screw. + The first three elements correspond to the angular component, + and the last three elements correspond to the linear components. + + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the hat operator of the input screw vector. + """ + vec6 = vec6.reshape(-1) # Ensure vec6 is a 1D array + + ang = vec6[:3].reshape((3, 1)) # Angular as a (3,1) vector + lin = vec6[3:].reshape((3, 1)) # Linear as a (3,1) vector + + angtilde = tilde_SE3(ang) # Tilde operator for angular part + + hat = jnp.block([[angtilde, lin], [jnp.zeros((1, 3)), jnp.zeros((1, 1))]]) + + return hat + + +def exp_SE3(vec6: Array) -> Array: + """ + Computes the exponential map for a 6D vector of se(3). + + Args: + vec6 (Array): array-like, shape (6,1) + A 6-dimensional vector representing the position. + The first three elements correspond to the angular component, + and the last three elements correspond to the linear component. + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the exponential map of the input screw vector. + """ + vec6 = vec6.reshape(-1) # Ensure vec6 is a 1D array + + phi = vec6[0] + theta = vec6[1] + psi = vec6[2] + cosphi, sinphi = jnp.cos(phi), jnp.sin(phi) + costheta, sintheta = jnp.cos(theta), jnp.sin(theta) + cospsi, sinpsi = jnp.cos(psi), jnp.sin(psi) + + p = vec6[3:].reshape((3, 1)) + + Rphi = jnp.array([[cosphi, -sinphi, 0], [sinphi, cosphi, 0], [0, 0, 1]]) + Rtheta = jnp.array([[1, 0, 0], [0, costheta, -sintheta], [0, sintheta, costheta]]) + Rpsi = jnp.array([[cospsi, -sinpsi, 0], [sinpsi, cospsi, 0], [0, 0, 1]]) + # Combine the rotations + R = Rpsi @ Rtheta @ Rphi # Rotation matrix + + g = jnp.block([[R, p], [jnp.zeros((1, 3)), jnp.ones((1, 1))]]) + + return g + + +def log_SE3(g: Array, eps: float) -> Array: + """ + Computes the logarithm map from SE(3) to se(3), i.e., extracts the twist from a transformation matrix. + + Args: + g (Array): array-like, shape (4, 4) + A transformation matrix in SE(3). + eps (float): tolerance to avoid division by zero in small angle approximations. + + Returns: + Array: shape (6,) + A 6D vector (twist) representing the logarithm of the transformation. + """ + R = g[:3, :3] + p = g[:3, 3].reshape((3, 1)) + + # Compute the rotation angle + trace_R = jnp.trace(R) + cos_theta = (trace_R - 1) / 2 + cos_theta = jnp.clip(cos_theta, -1.0, 1.0) # For numerical stability + theta = jnp.arccos(cos_theta) + + # Logarithm of R + omega_hat = lax.cond( + jnp.abs(theta) < eps, + lambda _: jnp.zeros((3, 3)), + lambda _: (theta / (2 * jnp.sin(theta))) * (R - R.T), + operand=None, + ) + + omega = jnp.array([omega_hat[2, 1], omega_hat[0, 2], omega_hat[1, 0]]).reshape( + (3, 1) + ) + + # Compute V inverse (Jacobian inverse) + omega_tilde = omega_hat + + def compute_V_inv(theta): + A = jnp.eye(3) - 0.5 * omega_tilde + B = (1 / (theta**2)) * ( + 1 - (theta * jnp.sin(theta)) / (2 * (1 - jnp.cos(theta))) + ) + V_inv = A + B * (omega_tilde @ omega_tilde) + return V_inv + + V_inv = lax.cond( + jnp.abs(theta) < eps, + lambda _: jnp.eye(3), + lambda _: compute_V_inv(theta), + operand=None, + ) + + v = V_inv @ p + + return jnp.vstack([omega, v]).reshape(-1) + + +def exp_gn_SE3(vec6: Array, eps: float) -> Array: + """ + Function to compute the exponential map of the Magnus expansion. + + Args: + vec6 (Array): shape (6,) JAX array + The screw vector representing the Magnus expansion. + + Returns: + g (Array): shape (4, 4) JAX array + The exponential map of the Magnus expansion. + """ + theta = jnp.linalg.norm(vec6[:3]) # Compute the norm of the angular part + vec6_hat = hat_SE3(vec6) # Compute the hat + + costheta = jnp.cos(theta) + sintheta = jnp.sin(theta) + + g = lax.cond( + theta <= eps, + lambda _: ( + jnp.eye(4) # Avoid division by zero + + vec6_hat + + 1 / 2 * jnp.linalg.matrix_power(vec6_hat, 2) + + 1 / 6 * jnp.linalg.matrix_power(vec6_hat, 3) + ), + lambda _: ( + jnp.eye(4) + + vec6_hat + + 1 + / jnp.power(theta, 2) + * (1 - costheta) + * jnp.linalg.matrix_power(vec6_hat, 2) + + 1 + / jnp.power(theta, 3) + * (theta - sintheta) + * jnp.linalg.matrix_power(vec6_hat, 3) + ), + operand=None, + ) + + return g + + +def adjoint_se3(vec6: Array) -> Array: + """ + Computes the adjoint representation of a vector of se(3). + + Args: + vec6 (Array): array-like, shape (3, 1) + A 6-dimensional vector representing the screw. + The first three elements correspond to the angular component, + and the last three elements correspond to the linear component. + + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the adjoint transformation of the input screw vector. + """ + vec6 = vec6.reshape(-1) # Ensure vec6 is a 1D array + + ang = vec6[:3].reshape((3, 1)) # Angular as a (3,1) vector + lin = vec6[3:].reshape((3, 1)) # Linear as a (3,1) vector + + angtilde = tilde_SE3(ang) # Tilde operator for angular part + lintilde = tilde_SE3(lin) # Tilde operator for linear part + + adj = jnp.block([[angtilde, jnp.zeros((3, 3))], [lintilde, angtilde]]) + + return adj + + +def coadjoint_se3(vec6: Array) -> Array: + """ + Computes the co-adjoint representation of a vector of se(3). + + Args: + vec6 (Array): array-like, shape (3, 1) + A 6-dimensional vector representing the screw. + The first three elements correspond to the angular component, + and the last three elements correspond to the linear component. + + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the co-adjoint transformation of the input screw vector. + """ + vec6 = vec6.reshape(-1) # Ensure vec6 is a 1D array + + ang = vec6[:3].reshape((3, 1)) # Angular as a (3,1) vector + lin = vec6[3:].reshape((3, 1)) # Linear as a (3,1) vector + + angtilde = tilde_SE3(ang) # Tilde operator for angular part + lintilde = tilde_SE3(lin) # Tilde operator for linear part + + adj_star = jnp.block([[angtilde, lintilde], [jnp.zeros((3, 3)), angtilde]]) + + return adj_star + + +def Adjoint_g_SE3(mat4: Array) -> Array: + """ + Computes the adjoint representation of a 4x4 matrix. + + Args: + mat4 (Array): array-like, shape (4,4) + A 4x4 matrix representing the transformation. + + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the Adjoint transformation of the input matrix. + """ + R = mat4[:3, :3] # Extract the angular part (top-left 3x3 block) + t = mat4[:3, 3].reshape((3, 1)) # Extract the linear part (top-right column) + + ttilde = tilde_SE3(t) # Tilde operator for linear part + + Adjoint = jnp.block([[R, jnp.zeros((3, 3))], [ttilde @ R, R]]) + + return Adjoint + + +def Adjoint_g_inv_SE3(mat4: Array) -> Array: + """ + Computes the adjoint representation of a 4x4 matrix. + + Args: + mat4 (Array): array-like, shape (4,4) + A 4x4 matrix representing the transformation. + + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the Adjoint transformation of the input matrix. + """ + R = mat4[:3, :3] # Extract the angular part (top-left 3x3 block) + t = mat4[:3, 3].reshape((3, 1)) # Extract the linear part (top-right column) + + ttilde = tilde_SE3(t) # Tilde operator for linear part + R_inv = jnp.transpose(R) # Since R is a rotation matrix, R^-1=R^T + + # Construct the inverse Adjoint matrix + inverse_Adjoint = jnp.block([[R_inv, jnp.zeros((3, 3))], [-R_inv @ ttilde, R_inv]]) + + return inverse_Adjoint + + +def Adjoint_gi_se3( + xi_i: Array, + s_i: float, + eps: float, +) -> Array: + """ + Computes the adjoint representation of a position of a points at s_i (local curvilinear coordinate) + along a rod in SE(3) deformed ine the current segment according to a strain vector xi_i. + + Args: + xi_i (Array): array-like, shape (6,1) + A 6-dimensional vector representing the screw in the current segment. + The first three elements correspond to the angular component, + and the last three elements correspond to the linear component. + s_i (float): + The curvilinear coordinate along the rod, representing the position of a point in the current segment. + eps (float): small value to avoid division by zero + + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the adjoint transformation of the input screw vector at the specified position. + """ + # We suppose here that theta is not zero thanks to a previous use of apply_eps + ang = xi_i[:3].reshape((3, 1)) # Angular as a (3,1) vector + theta = jnp.linalg.norm(ang) # Compute the norm of the angular part + adjoint_xi_i = adjoint_se3(xi_i) # Adjoint representation of the input vector + + cos = jnp.cos(s_i * theta) + sin = jnp.sin(s_i * theta) + + Adjoint = lax.cond( + jnp.abs(theta) <= eps, + lambda _: jnp.eye(6) + s_i * adjoint_xi_i, # Avoid division by zero + lambda _: ( + jnp.eye(6) + + 1 / (2 * theta) * (3 * sin - s_i * theta * cos) * adjoint_xi_i + + 1 + / (2 * jnp.power(theta, 2)) + * (4 - 4 * cos - s_i * theta * sin) + * jnp.linalg.matrix_power(adjoint_xi_i, 2) + + 1 + / (2 * jnp.power(theta, 3)) + * (sin - s_i * theta * cos) + * jnp.linalg.matrix_power(adjoint_xi_i, 3) + + 1 + / (2 * jnp.power(theta, 4)) + * (2 - 2 * cos - s_i * theta * sin) + * jnp.linalg.matrix_power(adjoint_xi_i, 4) + ), + operand=None, + ) + + return Adjoint + + +def Adjoint_gi_se3_inv( + xi_i: Array, + s_i: float, + eps: float, +) -> Array: + """ + Computes the adjoint representation of a position of a points at s_i (local curvilinear coordinate) + along a rod in SE(3) deformed ine the current segment according to a strain vector xi_i. + + Args: + xi_i (Array): array-like, shape (6,1) + A 6-dimensional vector representing the screw in SE(3). + The first three elements correspond to the angular component, + and the last three elements correspond to the linear component. + s_i (float): + The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. + eps (float): small value to avoid division by zero + + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the adjoint transformation of the input screw vector at the specified position. + """ + Adj = Adjoint_gi_se3( + xi_i, s_i, eps=eps + ) # Adjoint representation of the input vector + + # Extract R and -Jt from the Adjoint matrix + R = Adj[:3, :3] + ttildeR = Adj[3:, :3] + + # Compute the inverse using the Schur complement + R_inv = jnp.transpose(R) # Since R is a rotation matrix, R^-1=R^T + ttilde = ttildeR @ R_inv # Compute the tilde operator for the linear part + # Construct the inverse Adjoint matrix + inverse_Adjoint = jnp.block([[R_inv, jnp.zeros((3, 3))], [-R_inv @ ttilde, R_inv]]) + + return inverse_Adjoint + + +def Tangent_gi_se3( + xi_i: Array, + s_i: float, + eps: float, +) -> Array: + """ + Computes the tangent representation of a position of a points at s_i (local curvilinear coordinate) + along a rod in SE(3) deformed in the current segment according to a strain vector xi_i. + + Args: + xi_i (Array): array-like, shape (6,1) + A 6-dimensional vector representing the screw in SE(3). + The first three elements correspond to the angular component, + and the last three elements correspond to the linear component. + s_i (float): + The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. + eps (float): small value to avoid division by zero + + Returns: + Array: shape (4, 4) + A 4x4 matrix representing the tangent transformation of the input screw vector at the specified position. + """ + # We suppose here that theta is not zero thanks to a previous use of apply_eps + ang = xi_i[:3].reshape((3, 1)) # Angular as a (3,1) vector + theta = jnp.linalg.norm(ang) # Compute the norm of the angular part + adjoint_xi_i = adjoint_se3(xi_i) # Adjoint representation of the input vector + + cos = jnp.cos(s_i * theta) + sin = jnp.sin(s_i * theta) + + Tangent = lax.cond( + jnp.abs(theta) <= eps, + lambda _: s_i * jnp.eye(6) + s_i**2 / 2 * adjoint_xi_i, + lambda _: ( + s_i * jnp.eye(6) + + 1 + / (2 * jnp.power(theta, 2)) + * (4 - 4 * cos - s_i * theta * sin) + * adjoint_xi_i + + 1 + / (2 * jnp.power(theta, 3)) + * (4 * s_i * theta - 5 * sin + s_i * theta * cos) + * jnp.linalg.matrix_power(adjoint_xi_i, 2) + + 1 + / (2 * jnp.power(theta, 4)) + * (2 - 2 * cos - s_i * theta * sin) + * jnp.linalg.matrix_power(adjoint_xi_i, 3) + + 1 + / (2 * jnp.power(theta, 5)) + * (2 * s_i * theta - 3 * sin + s_i * theta * cos) + * jnp.linalg.matrix_power(adjoint_xi_i, 4) + ), + operand=None, + ) + + return Tangent + + +# ================================================================================================ +# Shared operators +# ============================ +def compute_weighted_sums(M: Array, vecm: Array, idx: int) -> Array: + """ + Compute the weighted sums of the matrix product of M and vecm, + + Args: + M (Array): array of shape (N, m, m) + Describes the matrix to be multiplied with vecm + vecm (Array): array-like of shape (N, m) + Describes the vector to be multiplied with M + idx (int): index of the last row to be summed over + + Returns: + Array: array of shape (N, m) + The result of the weighted sums. For each i, the result is the sum of the products of M[i, j] and vecm[j] for j from 0 to idx. + """ + N = M.shape[0] + # Matrix product for each j: (N, m, m) @ (N, m, 1) -> (N, m) + prod = jnp.einsum("nij,nj->ni", M, vecm) + + # Triangular mask for partial sum: (N, N) + # mask[i, j] = 1 if j >= i and j <= idx + mask = (jnp.arange(N)[:, None] <= jnp.arange(N)[None, :]) & ( + jnp.arange(N)[None, :] <= idx + ) + mask = mask.astype(M.dtype) # (N, N) + + # Extend 6-dimensional mask (N, N, 1) to apply to (N, m) + masked_prod = mask[:, :, None] * prod[None, :, :] # (N, N, m) + + # Sum over j for each i : (N, m) + result = masked_prod.sum(axis=1) # (N, m) + return result + + +if __name__ == "__main__": + vec6 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + exp_SE3_result = exp_SE3(vec6) + print("Exponential map of SE(3):") + print(exp_SE3_result) diff --git a/src/jsrm/utils/lie_operators.py b/src/jsrm/utils/lie_operators.py deleted file mode 100644 index 4ca7dca..0000000 --- a/src/jsrm/utils/lie_operators.py +++ /dev/null @@ -1,608 +0,0 @@ -import jax.numpy as jnp - -# for documentation -from jax import Array -from typing import Sequence - - -def tilde_SE3(vec3: Array) -> Array: - """ - Computes the tilde operator of SE(3) for a 3D vector. - - Args: - vec3 (Array): array-like, shape (3,1) - A 3-dimensional vector. - - Returns: - Array: shape (3, 3) - A 3x3 matrix representing the tilde operator of the input vector. - """ - vec3 = vec3.reshape(-1) # Ensure vec3 is a 1D array - - # Extract components of the vector - x, y, z = vec3.flatten() - - # Use JAX's array creation for better performance - Mtilde = jnp.array([[0, -z, y], [z, 0, -x], [-y, x, 0]]) - return Mtilde - - -def adjoint_SE3(vec6: Array) -> Array: - """ - Computes the adjoint representation of a vector of se(3). - - Args: - vec6 (Array): array-like, shape (6,1) - A 6-dimensional vector representing the screw. - The first three elements correspond to the angular component, - and the last three elements correspond to the linear component. - - Returns: - Array: shape (6, 6) - A 6x6 matrix representing the adjoint transformation of the input screw vector. - """ - vec6 = vec6.reshape(-1) # Ensure vec6 is a 1D array - - ang = vec6[:3].reshape((3, 1)) # Angular as a (3,1) vector - lin = vec6[3:].reshape((3, 1)) # Linear as a (3,1) vector - - angtilde = tilde_SE3(ang) # Tilde operator for angular part - lintilde = tilde_SE3(lin) # Tilde operator for linear part - - adj = jnp.block([[angtilde, jnp.zeros((3, 3))], [lintilde, angtilde]]) - - return adj - - -def adjoint_star_SE3(vec6: Array) -> Array: - """ - Computes the co-adjoint representation of a vector of se(3). - - Args: - vec6 (Array): array-like, shape (6,1) - A 6-dimensional vector representing the screw. - The first three elements correspond to the angular component, - and the last three elements correspond to the linear component. - - Returns: - Array: shape (6, 6) - A 6x6 matrix representing the co-adjoint transformation of the input screw vector. - """ - vec6 = vec6.reshape(-1) # Ensure vec6 is a 1D array - - ang = vec6[:3].reshape((3, 1)) # Angular as a (3,1) vector - lin = vec6[3:].reshape((3, 1)) # Linear as a (3,1) vector - - angtilde = tilde_SE3(ang) # Tilde operator for angular part - lintilde = tilde_SE3(lin) # Tilde operator for linear part - - adj_star = jnp.block([[angtilde, lintilde], [jnp.zeros((3, 3)), angtilde]]) - - return adj_star - - -def hat_SE3(vec6: Array) -> Array: - """ - Computes the hat operator for a 6D vector of se(3). - - Args: - vec6 (Array): array-like, shape (6,1) - A 6-dimensional vector representing the screw. - The first three elements correspond to the angular component, - and the last three elements correspond to the linear component. - - Returns: - Array: shape (4, 4) - A 4x4 matrix representing the hat operator of the input screw vector. - """ - vec6 = vec6.reshape(-1) # Ensure vec6 is a 1D array - - ang = vec6[:3].reshape((3, 1)) # Angular as a (3,1) vector - lin = vec6[3:].reshape((3, 1)) # Linear as a (3,1) vector - - angtilde = tilde_SE3(ang) # Tilde operator for angular part - - hat = jnp.block([[angtilde, lin], [jnp.zeros((1, 3)), jnp.zeros((1, 1))]]) - - return hat - - -def Adjoint_g_SE3(mat4: Array) -> Array: - """ - Computes the adjoint representation of a 4x4 matrix. - - Args: - mat4 (Array): array-like, shape (4,4) - A 4x4 matrix representing the transformation. - - Returns: - Array: shape (4, 4) - A 4x4 matrix representing the Adjoint transformation of the input matrix. - """ - ang = mat4[:3, :3] # Extract the angular part (top-left 3x3 block) - lin = mat4[:3, 3].reshape((3, 1)) # Extract the linear part (top-right column) - - ltilde = tilde_SE3(lin) # Tilde operator for linear part - - Adjoint = jnp.block([[ang, jnp.zeros((3, 3))], [ltilde @ ang, ang]]) - - return Adjoint - - -def Adjoint_gn_SE3( - xi_n: Array, - l_nprev: float, - s: float, -) -> Array: - """ - Computes the adjoint representation of a position of a points at s (general curvilinear coordinate) - along a rod in SE(3) deformed ine the current segment according to a strain vector xi_n. - - If s is a point of the n-th segment, this function use the lenght from the origin of the rod to the begining of the n-th segment. - - Args: - xi_n (Array): array-like, shape (6,1) - A 6-dimensional vector representing the screw in SE(3). - l_nprev (float): - The length from the origin of the rod to the beginning of the n-th segment. - s (float): - The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. - - Returns: - Array: shape (6, 6) - A 6x6 matrix representing the adjoint transformation of the input screw vector at the specified position. - """ - # We suppose here that theta is not zero thanks to a previous use of apply_eps - ang = xi_n[:3].reshape((3, 1)) # Angular as a (3,1) vector - theta = jnp.linalg.norm(ang) # Compute the norm of the angular part - x = s - l_nprev # Compute the segment length - adjoint_xi_n = adjoint_SE3(xi_n) # Adjoint representation of the input vector - - cos = jnp.cos(x * theta) - sin = jnp.sin(x * theta) - Adjoint = ( - jnp.eye(6) - + 1 / (2 * theta) * (3 * sin - x * theta * cos) * adjoint_xi_n - + 1 - / (2 * jnp.power(theta, 2)) - * (4 - 4 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 2) - + 1 - / (2 * jnp.power(theta, 3)) - * (sin - x * theta * cos) - * jnp.linalg.matrix_power(adjoint_xi_n, 3) - + 1 - / (2 * jnp.power(theta, 4)) - * (2 - 2 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 4) - ) - - return Adjoint - - -def Adjoint_gn_SE3_inv( - xi_n: Array, - l_nprev: float, - s: float, -) -> Array: - """ - Computes the adjoint representation of a position of a points at s (general curvilinear coordinate) - along a rod in SE(3) deformed ine the current segment according to a strain vector xi_n. - - If s is a point of the n-th segment, this function use the lenght from the origin of the rod to the begining of the n-th segment. - - Args: - xi_n (Array): array-like, shape (6,1) - A 6-dimensional vector representing the screw in SE(3). - l_nprev (float): - The length from the origin of the rod to the beginning of the n-th segment. - s (float): - The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. - - Returns: - Array: shape (6, 6) - A 6x6 matrix representing the adjoint transformation of the input screw vector at the specified position. - """ - # We suppose here that theta is not zero thanks to a previous use of apply_eps - ang = xi_n[:3].reshape((3, 1)) # Angular as a (3,1) vector - theta = jnp.linalg.norm(ang) # Compute the norm of the angular part - x = s - l_nprev # Compute the segment length - adjoint_xi_n = adjoint_SE3(xi_n) # Adjoint representation of the input vector - - cos = jnp.cos(x * theta) - sin = jnp.sin(x * theta) - - Adjoint = ( - jnp.eye(6) - + 1 / (2 * theta) * (3 * sin - x * theta * cos) * adjoint_xi_n - + 1 - / (2 * jnp.power(theta, 2)) - * (4 - 4 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 2) - + 1 - / (2 * jnp.power(theta, 3)) - * (sin - x * theta * cos) - * jnp.linalg.matrix_power(adjoint_xi_n, 3) - + 1 - / (2 * jnp.power(theta, 4)) - * (2 - 2 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 4) - ) - - # Extract R and uR from the Adjoint matrix - R = Adjoint[:3, :3] - uR = Adjoint[3:, :3] - - # Compute the inverse using the Schur complement - R_inv = jnp.transpose(R) # Since R is a rotation matrix - u = jnp.dot(uR, R_inv) # Compute the linear part - uR_inv = -jnp.dot(R_inv, u) # Compute the inverse linear part - - # Construct the inverse Adjoint matrix - inverse_Adjoint = jnp.block([[R_inv, jnp.zeros((3, 3))], [uR_inv, R_inv]]) - - return inverse_Adjoint - - -def Tangent_gn_SE3( - xi_n: Array, - l_nprev: float, - s: float, -) -> Array: - """ - Computes the tangent representation of a position of a points at s (general curvilinear coordinate) - along a rod in SE(3) deformed in the current segment according to a strain vector xi_n. - - If s is a point of the n-th segment, this function use the length from the origin of the rod to the beginning of the n-th segment. - - Args: - xi_n (Array): array-like, shape (6,1) - A 6-dimensional vector representing the screw in SE(3). - l_nprev (float): - The length from the origin of the rod to the beginning of the n-th segment. - s (float): - The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. - - Returns: - Array: shape (6, 6) - A 6x6 matrix representing the tangent transformation of the input screw vector at the specified position. - """ - # We suppose here that theta is not zero thanks to a previous use of apply_eps - ang = xi_n[:3].reshape((3, 1)) # Angular as a (3,1) vector - theta = jnp.linalg.norm(ang) # Compute the norm of the angular part - x = s - l_nprev # Compute the segment length - adjoint_xi_n = adjoint_SE3(xi_n) # Adjoint representation of the input vector - - cos = jnp.cos(x * theta) - sin = jnp.sin(x * theta) - - Tangent = ( - x * jnp.eye(6) - + 1 / (2 * jnp.power(theta, 2)) * (4 - 4 * cos - x * theta * sin) * adjoint_xi_n - + 1 - / (2 * jnp.power(theta, 3)) - * (4 * x * theta - 5 * sin + x * theta * cos) - * jnp.linalg.matrix_power(adjoint_xi_n, 2) - + 1 - / (2 * jnp.power(theta, 4)) - * (2 - 2 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 3) - + 1 - / (2 * jnp.power(theta, 5)) - * (2 * x * theta - 3 * sin + x * theta * cos) - * jnp.linalg.matrix_power(adjoint_xi_n, 4) - ) - - return Tangent - - -def vec_SE2_to_xi_SE3(vec3: Array, indices: Sequence[int] = (2, 3, 4)) -> Array: - """ - Convert a strain vector in se(2) to a strain vector in se(3). - - Args: - vec3 (Array): array-like, shape (3,1) - A 3-dimensional vector representing the strain in se(2). - The first element correspond to the angular component, - and the last elements corresponds to the linear component. - indices (Sequence[int], optional): Indices in the 6D se(3) vector - where to insert the se(2) components. Default is (2, 3, 4) - - Returns: - Array: shape (6,1) - A 6-dimensional vector representing the strain in se(3). - The first three elements correspond to the angular component, - and the last three elements correspond to the linear component. - """ - vec3 = jnp.asarray(vec3).flatten() # Ensure vec3 is a JAX array - - xi = jnp.zeros((6,)) # Initialize a 6D vector with zeros - xi = xi.at[jnp.array(indices)].set(vec3) # Set the values at the specified indices - return xi.reshape((6, 1)) - - -# ================================================================================================ -# SE(2) operators -# =================================== -J = jnp.array([[0, -1], [1, 0]]) - - -def adjoint_SE2(vec3: Array) -> Array: - """ - Computes the adjoint representation of a vector of se(2). - - Args: - vec3 (Array): array-like, shape (3, 1) - A 3-dimensional vector representing the screw. - The first element correspond to the angular component, - and the last two elements correspond to the linear component. - - Returns: - Array: shape (3, 3) - A 3x3 matrix representing the adjoint transformation of the input screw vector. - """ - vec3 = vec3.reshape(-1) # Ensure vec6 is a 1D array - - ang = vec3[0] - lin = vec3[1:].reshape((2, 1)) # Linear as a (3,1) vector - - adj = jnp.concatenate( - [jnp.zeros((1, 3)), jnp.concatenate([-J @ lin, ang * J], axis=1)] - ) - - return adj - - -def adjoint_star_SE2(vec3: Array) -> Array: - """ - Computes the co-adjoint representation of a vector of se(2). - - Args: - vec3 (Array): array-like, shape (3, 1) - A 3-dimensional vector representing the screw. - The first element correspond to the angular component, - and the last two elements correspond to the linear component. - - Returns: - Array: shape (3, 3) - A 3x3 matrix representing the co-adjoint transformation of the input screw vector. - """ - vec3 = vec3.reshape(-1) # Ensure vec6 is a 1D array - - ang = vec3[0] - lin = vec3[1:].reshape((2, 1)) # Linear as a (3,1) vector - - adj_star = jnp.concatenate( - [jnp.zeros((3, 1)), jnp.concatenate([lin.T @ J, ang * J], axis=0)], axis=1 - ) - - return adj_star - - -def Adjoint_g_SE2(mat3: Array) -> Array: - """ - Computes the adjoint representation of a 3x3 matrix. - - Args: - mat4 (Array): array-like, shape (4,4) - A 4x4 matrix representing the transformation. - - Returns: - Array: shape (3, 3) - A 3x3 matrix representing the Adjoint transformation of the input matrix. - """ - R = mat3[:2, :2] # Extract the angular part (top-left 2x2 block) - t = mat3[:2, 2].reshape((2, 1)) # Extract the linear part (top-right column) - - Adjoint = jnp.concatenate( - [ - jnp.concatenate([jnp.ones(((1, 1))), jnp.zeros((1, 2))], axis=1), - jnp.concatenate([-J @ t, R], axis=1), - ] - ) - - return Adjoint - - -def Adjoint_gn_SE2( - xi_n: Array, - l_nprev: float, - s: float, -) -> Array: - """ - Computes the adjoint representation of a position of a points at s (general curvilinear coordinate) - along a rod in SE(2) deformed ine the current segment according to a strain vector xi_n. - - If s is a point of the n-th segment, this function use the lenght from the origin of the rod to the begining of the n-th segment. - - Args: - xi_n (Array): array-like, shape (3,1) - A 3-dimensional vector representing the screw in SE(2). - l_nprev (float): - The length from the origin of the rod to the beginning of the n-th segment. - s (float): - The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. - - Returns: - Array: shape (3, 3) - A 3x3 matrix representing the adjoint transformation of the input screw vector at the specified position. - """ - # We suppose here that theta is not zero thanks to a previous use of apply_eps - theta = xi_n[0] # Angular part - x = s - l_nprev # Compute the segment length - adjoint_xi_n = adjoint_SE2(xi_n) # Adjoint representation of the input vector - - cos = jnp.cos(x * theta) - sin = jnp.sin(x * theta) - - Adjoint = ( - jnp.eye(3) - + 1 / (2 * theta) * (3 * sin - x * theta * cos) * adjoint_xi_n - + 1 - / (2 * jnp.power(theta, 2)) - * (4 - 4 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 2) - + 1 - / (2 * jnp.power(theta, 3)) - * (sin - x * theta * cos) - * jnp.linalg.matrix_power(adjoint_xi_n, 3) - + 1 - / (2 * jnp.power(theta, 4)) - * (2 - 2 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 4) - ) - - return Adjoint - - -def Adjoint_gn_SE2_inv( - xi_n: Array, - l_nprev: float, - s: float, -) -> Array: - """ - Computes the adjoint representation of a position of a points at s (general curvilinear coordinate) - along a rod in SE(2) deformed ine the current segment according to a strain vector xi_n. - - If s is a point of the n-th segment, this function use the lenght from the origin of the rod to the begining of the n-th segment. - - Args: - xi_n (Array): array-like, shape (3,1) - A 3-dimensional vector representing the screw in SE(2). - l_nprev (float): - The length from the origin of the rod to the beginning of the n-th segment. - s (float): - The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. - - Returns: - Array: shape (3, 3) - A 3x3 matrix representing the adjoint transformation of the input screw vector at the specified position. - """ - # We suppose here that theta is not zero thanks to a previous use of apply_eps - theta = xi_n[0] # Angular part - x = s - l_nprev # Compute the segment length - adjoint_xi_n = adjoint_SE2(xi_n) # Adjoint representation of the input vector - - cos = jnp.cos(x * theta) - sin = jnp.sin(x * theta) - - Adjoint = ( - jnp.eye(3) - + 1 / (2 * theta) * (3 * sin - x * theta * cos) * adjoint_xi_n - + 1 - / (2 * jnp.power(theta, 2)) - * (4 - 4 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 2) - + 1 - / (2 * jnp.power(theta, 3)) - * (sin - x * theta * cos) - * jnp.linalg.matrix_power(adjoint_xi_n, 3) - + 1 - / (2 * jnp.power(theta, 4)) - * (2 - 2 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 4) - ) - - # Extract R and -Jt from the Adjoint matrix - R = Adjoint[1:, 1:] - mJt = Adjoint[1:, 0].reshape(-1, 1) - - # Compute the inverse using the Schur complement - R_inv = jnp.transpose(R) # Since R is a rotation matrix, R^-1=R^T - # Construct the inverse Adjoint matrix - inverse_Adjoint = jnp.concatenate( - [ - jnp.concatenate([jnp.ones(((1, 1))), jnp.zeros((1, 2))], axis=1), - jnp.concatenate([-R_inv @ mJt, R_inv], axis=1), - ] - ) - - return inverse_Adjoint - - -def Tangent_gn_SE2( - xi_n: Array, - l_nprev: float, - s: float, -) -> Array: - """ - Computes the tangent representation of a position of a points at s (general curvilinear coordinate) - along a rod in SE(2) deformed in the current segment according to a strain vector xi_n. - - If s is a point of the n-th segment, this function use the length from the origin of the rod to the beginning of the n-th segment. - - Args: - xi_n (Array): array-like, shape (3,1) - A 3-dimensional vector representing the screw in SE(2). - l_nprev (float): - The length from the origin of the rod to the beginning of the n-th segment. - s (float): - The curvilinear coordinate along the rod, representing the position of a point in the n-th segment. - - Returns: - Array: shape (3, 3) - A 3x3 matrix representing the tangent transformation of the input screw vector at the specified position. - """ - # We suppose here that theta is not zero thanks to a previous use of apply_eps - theta = xi_n[0] # Angular part - x = s - l_nprev # Compute the segment length - adjoint_xi_n = adjoint_SE2(xi_n) # Adjoint representation of the input vector - - cos = jnp.cos(x * theta) - sin = jnp.sin(x * theta) - - Tangent = ( - x * jnp.eye(3) - + 1 / (2 * jnp.power(theta, 2)) * (4 - 4 * cos - x * theta * sin) * adjoint_xi_n - + 1 - / (2 * jnp.power(theta, 3)) - * (4 * x * theta - 5 * sin + x * theta * cos) - * jnp.linalg.matrix_power(adjoint_xi_n, 2) - + 1 - / (2 * jnp.power(theta, 4)) - * (2 - 2 * cos - x * theta * sin) - * jnp.linalg.matrix_power(adjoint_xi_n, 3) - + 1 - / (2 * jnp.power(theta, 5)) - * (2 * x * theta - 3 * sin + x * theta * cos) - * jnp.linalg.matrix_power(adjoint_xi_n, 4) - ) - - return Tangent - - -# ================================================================================================ -# Shared operators -# ============================ -def compute_weighted_sums(M: Array, vecm: Array, idx: int) -> Array: - """ - Compute the weighted sums of the matrix product of M and vecm, - - Args: - M (Array): array of shape (N, m, m) - Describes the matrix to be multiplied with vecm - vecm (Array): array-like of shape (N, m) - Describes the vector to be multiplied with M - idx (int): index of the last row to be summed over - - Returns: - Array: array of shape (N, m) - The result of the weighted sums. For each i, the result is the sum of the products of M[i, j] and vecm[j] for j from 0 to idx. - """ - N = M.shape[0] - # Matrix product for each j: (N, m, m) @ (N, m, 1) -> (N, m) - prod = jnp.einsum("nij,nj->ni", M, vecm) - - # Triangular mask for partial sum: (N, N) - # mask[i, j] = 1 if j >= i and j <= idx - mask = (jnp.arange(N)[:, None] <= jnp.arange(N)[None, :]) & ( - jnp.arange(N)[None, :] <= idx - ) - mask = mask.astype(M.dtype) # (N, N) - - # Extend 6-dimensional mask (N, N, 1) to apply to (N, m) - masked_prod = mask[:, :, None] * prod[None, :, :] # (N, N, m) - - # Sum over j for each i : (N, m) - result = masked_prod.sum(axis=1) # (N, m) - return result