Skip to content

Commit 5fec620

Browse files
committed
Fix some type hinting errors
1 parent b8c1fe6 commit 5fec620

File tree

5 files changed

+66
-63
lines changed

5 files changed

+66
-63
lines changed

examples/benchmark_planar_pcs.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
import matplotlib.pyplot as plt
1010
import numpy as onp
1111
from pathlib import Path
12-
from typing import Callable, Dict, Optional, Literal
12+
from typing import Callable, Dict, Optional, Literal, Union
1313

1414
import jsrm
1515
from jsrm import ode_factory
16-
from jsrm.systems import planar_pcs, planar_pcs_num
16+
from jsrm.systems import planar_pcs_num, planar_pcs_sym
1717

1818
import time
1919
import pickle
@@ -25,21 +25,21 @@ def simulate_planar_pcs_value_eval(
2525
num_segments: int,
2626
type_of_derivation: Optional[Literal["symbolic", "numeric"]] = "symbolic",
2727
type_of_integration: Optional[Literal["gauss-legendre", "gauss-kronrad", "trapezoid"]] = "gauss-legendre",
28-
param_integration: int = None,
28+
param_integration: Optional[int] = None,
2929
type_of_jacobian: Optional[Literal["explicit", "autodiff"]] = "explicit",
30-
robot_params: Dict[str, Array] = None,
31-
strain_selector: Array = None,
32-
q0: Array = None,
33-
q_d0: Array = None,
30+
robot_params: Optional[Dict[str, Array]] = None,
31+
strain_selector: Optional[Array] = None,
32+
q0: Optional[Array] = None,
33+
q_d0: Optional[Array] = None,
3434
t: float = 1.0,
35-
dt: float = None,
35+
dt: Optional[float] = None,
3636
bool_print: bool = True,
3737
bool_plot: bool = True,
3838
bool_save_plot: bool = False,
3939
bool_save_video: bool = True,
4040
bool_save_res: bool = False,
41-
results_path: str = None,
42-
results_path_extension: str = None
41+
results_path: Optional[Union[str, Path]] = None,
42+
results_path_extension: Optional[str] = None
4343
) -> Dict:
4444
"""
4545
Simulate a planar PCS model. Save the video and figures.
@@ -81,7 +81,7 @@ def simulate_planar_pcs_value_eval(
8181
Defaults to None, which will use the default path.
8282
8383
Returns:
84-
TODO
84+
simulation_dict (Dict): dictionary with the simulation results.
8585
"""
8686

8787
# ===================================================
@@ -253,13 +253,12 @@ def simulate_planar_pcs_value_eval(
253253
results_path = (results_path_parent / file_name).with_suffix(".pkl")
254254

255255
if isinstance(results_path, str) or isinstance(results_path, Path):
256-
results_path = Path(results_path)
257-
if results_path.suffix != ".pkl":
256+
results_path_obj = Path(results_path)
257+
if results_path_obj.suffix != ".pkl":
258258
raise ValueError(
259-
f"results_path must have the suffix .pkl, but got {results_path.suffix}"
259+
f"results_path must have the suffix .pkl, but got {results_path_obj.suffix}"
260260
)
261-
else:
262-
results_path = Path(results_path)
261+
results_path = results_path_obj
263262
else:
264263
raise TypeError(
265264
f"results_path must be a string, but got {type(results_path).__name__}"
@@ -390,7 +389,7 @@ def draw_robot(
390389

391390
if type_of_derivation == "symbolic":
392391
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
393-
planar_pcs.factory(sym_exp_filepath, strain_selector)
392+
planar_pcs_sym.factory(sym_exp_filepath, strain_selector)
394393
)
395394
elif type_of_derivation == "numeric":
396395
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
@@ -441,7 +440,7 @@ def draw_robot(
441440
y0=x0,
442441
max_steps=None,
443442
saveat=SaveAt(ts=video_ts))
444-
diffeqsolve_fn = jit(diffeqsolve_fn)
443+
diffeqsolve_fn = jit(diffeqsolve_fn) # type: ignore
445444

446445
print("Solving the ODE ...")
447446
sol = diffeqsolve_fn()
@@ -463,7 +462,7 @@ def draw_robot(
463462
forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, robot_params, s=s_max)
464463

465464
print("JIT-compiling the forward kinematics function...")
466-
forward_kinematics_fn_end_effector = jit(forward_kinematics_fn_end_effector)
465+
forward_kinematics_fn_end_effector = jit(forward_kinematics_fn_end_effector) # type: ignore
467466
forward_kinematics_fn_end_effector = vmap(forward_kinematics_fn_end_effector)
468467

469468
print("Computing the end-effector position along the trajectory...")
@@ -619,7 +618,7 @@ def draw_robot(
619618
)
620619

621620
# Initialize the video
622-
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
621+
fourcc = cv2.VideoWriter_fourcc(*"mp4v") # type: ignore
623622
video = cv2.VideoWriter(
624623
str(video_path),
625624
fourcc,
@@ -649,6 +648,7 @@ def draw_robot(
649648
# ===========================
650649
if bool_save_res:
651650
print("Saving the simulation results...")
651+
assert results_path is not None, "results_path should not be None when saving results"
652652
with open(results_path, "wb") as f:
653653
pickle.dump(simulation_dict, f)
654654
print(f"Simulation results saved at {results_path} \n")
@@ -665,20 +665,20 @@ def simulate_planar_pcs_time_eval(
665665
num_segments: int,
666666
type_of_derivation: str = "symbolic",
667667
type_of_integration: str = "gauss-legendre",
668-
param_integration: int = None,
668+
param_integration: Optional[int] = None,
669669
type_of_jacobian: str = "explicit",
670-
robot_params: Dict[str, Array] = None,
671-
strain_selector: Array = None,
672-
q0: Array = None,
673-
q_d0: Array = None,
670+
robot_params: Optional[Dict[str, Array]] = None,
671+
strain_selector: Optional[Array] = None,
672+
q0: Optional[Array] = None,
673+
q_d0: Optional[Array] = None,
674674
t: float = 1.0,
675-
dt: float = None,
675+
dt: Optional[float] = None,
676676
bool_save_res: bool = False,
677-
results_path: str = None,
678-
results_path_extension: str = None,
679-
type_time = "once",
680-
nb_eval : int = None,
681-
nb_samples: int = None
677+
results_path: Optional[Union[str, Path]] = None,
678+
results_path_extension: Optional[str] = None,
679+
type_time: str = "once",
680+
nb_eval: Optional[int] = None,
681+
nb_samples: Optional[int] = None
682682
) -> Dict:
683683
"""
684684
Simulate a planar PCS model. Save the video and figures.
@@ -922,13 +922,12 @@ def simulate_planar_pcs_time_eval(
922922
results_path = (results_path_parent / file_name).with_suffix(".pkl")
923923

924924
if isinstance(results_path, str) or isinstance(results_path, Path):
925-
results_path = Path(results_path)
926-
if results_path.suffix != ".pkl":
925+
results_path_obj = Path(results_path)
926+
if results_path_obj.suffix != ".pkl":
927927
raise ValueError(
928-
f"results_path must have the suffix .pkl, but got {results_path.suffix}"
928+
f"results_path must have the suffix .pkl, but got {results_path_obj.suffix}"
929929
)
930-
else:
931-
results_path = Path(results_path)
930+
results_path = results_path_obj
932931
else:
933932
raise TypeError(
934933
f"results_path must be a string, but got {type(results_path).__name__}"
@@ -982,7 +981,7 @@ def simulate_planar_pcs_time_eval(
982981
timer_start = time.time()
983982
if type_of_derivation == "symbolic":
984983
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
985-
planar_pcs.factory(sym_exp_filepath, strain_selector)
984+
planar_pcs_sym.factory(sym_exp_filepath, strain_selector)
986985
)
987986
elif type_of_derivation == "numeric":
988987
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
@@ -1330,7 +1329,7 @@ def time_diffeqsolve_over_time():
13301329

13311330
# JIT the functions
13321331
print("JIT-compiling the ODE function...")
1333-
diffeqsolve_fn = jit(diffeqsolve_fn)
1332+
diffeqsolve_fn = jit(diffeqsolve_fn) # type: ignore
13341333

13351334
# First evaluation of the ODE to trigger JIT compilation
13361335
print("Solving the ODE for the first time (JIT-compilation)...")
@@ -1579,6 +1578,7 @@ def time_kinetic_energy_over_time():
15791578
# ===========================
15801579
if bool_save_res:
15811580
print("Saving the simulation results...")
1581+
assert results_path is not None, "results_path should not be None when saving results"
15821582
with open(results_path, "wb") as f:
15831583
pickle.dump(simulation_dict, f)
15841584
print(f"Simulation results saved at {results_path} \n")

src/jsrm/systems/planar_pcs_num.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from quadax import GaussKronrodRule
66

77
import numpy as onp
8-
from typing import Callable, Dict, Tuple, Optional, Literal
8+
from typing import Callable, Dict, Tuple, Optional, Literal, Any
99

1010
from .utils import (
1111
compute_strain_basis,
@@ -33,25 +33,25 @@
3333

3434
def factory(
3535
num_segments: int,
36-
strain_selector: Array=None,
36+
strain_selector: Optional[Array] = None,
3737
xi_eq: Optional[Array] = None,
3838
stiffness_fn: Optional[Callable] = None,
3939
actuation_mapping_fn: Optional[Callable] = None,
4040
global_eps: float = jnp.finfo(jnp.float32).eps,
4141
integration_type: Optional[Literal["gauss-legendre", "gauss-kronrad", "trapezoid"]] = "gauss-legendre",
42-
param_integration: int = None,
42+
param_integration: Optional[int] = None,
4343
jacobian_type: Optional[Literal["explicit", "autodiff"]] = "explicit"
44-
) -> Tuple[
44+
) -> Tuple[
4545
Array,
4646
Callable[
47-
[Dict[str, Array], Array, Array, Optional[float]],
47+
[Dict[str, Array], Array, Array, float],
4848
Array
4949
],
5050
Callable[
51-
[Dict[str, Array], Array, Array, Optional[float]],
51+
[Dict[str, Array], Array, Array, float],
5252
Tuple[Array, Array, Array, Array, Array, Array],
5353
],
54-
Dict[str, Callable],
54+
Dict[str, Callable[..., Any]],
5555
]:
5656
"""
5757
Factory function to create the forward kinematics function for a planar robot.
@@ -190,7 +190,7 @@ def stiffness_fn(
190190
S = B_xi.T @ S @ B_xi
191191

192192
return S
193-
if not isinstance(stiffness_fn, Callable):
193+
if not isinstance(stiffness_fn, callable):
194194
raise TypeError(f"stiffness_fn must be a callable, but got {type(stiffness_fn).__name__}")
195195

196196
# Actuation mapping function
@@ -295,7 +295,7 @@ def apply_eps_to_bend_strains(
295295
def classify_segment(
296296
params: Dict[str, Array],
297297
s: Array
298-
) -> Tuple[Array, Array]:
298+
) -> Tuple[Array, Array, Array]:
299299
"""
300300
Classify the point along the robot to the corresponding segment.
301301
@@ -306,6 +306,7 @@ def classify_segment(
306306
Returns:
307307
segment_idx (Array): index of the segment where the point is located
308308
s_segment (Array): point coordinate along the segment in the interval [0, l_segment]
309+
l_cum (Array): cumulative length of the segments starting with 0
309310
"""
310311
l = params["l"]
311312

@@ -501,7 +502,7 @@ def J_explicit_local(
501502
def J_i(
502503
tuple_J_prev: Array,
503504
i: int
504-
) -> Array:
505+
) -> Tuple[Tuple[Array, Array], Array]:
505506
J_prev_Lprev, _ = tuple_J_prev
506507

507508
start_index = 3 * i
@@ -534,7 +535,8 @@ def J_i(
534535
_, J_array = lax.scan(
535536
f = J_i,
536537
init = tuple_J_0,
537-
xs = jnp.arange(1, num_segments))
538+
xs = jnp.arange(1, num_segments)
539+
)
538540

539541
# Add the initial condition to the Jacobian array
540542
J_array = jnp.concatenate([J_0_s[jnp.newaxis, ...], J_array], axis=0)
@@ -597,7 +599,7 @@ def J_explicit_global(
597599
def J_i(
598600
tuple_J_prev: Array,
599601
i: int
600-
) -> Array:
602+
) -> Tuple[Tuple[Array, Array], Array]:
601603
J_prev_Lprev, _ = tuple_J_prev
602604

603605
start_index = 3 * i
@@ -708,7 +710,7 @@ def J_Jd_explicit_local(
708710
xi_d: Array,
709711
s: Array,
710712
eps: float = global_eps
711-
) -> Array:
713+
) -> Tuple[Array, Array]:
712714
"""
713715
Compute the body-frame jacobian and its derivative with respect to the strain vector
714716
at a given point s using explicit expression in SE(2).
@@ -753,7 +755,7 @@ def J_Jd_explicit_local(
753755
def J_i(
754756
tuple_J_prev: Array,
755757
i: int
756-
) -> Array:
758+
) -> Tuple[Tuple[Array, Array], Array]:
757759
J_prev_Lprev, _ = tuple_J_prev
758760

759761
start_index = 3 * i
@@ -837,7 +839,7 @@ def J_Jd_explicit_global(
837839
xi_d: Array,
838840
s: Array,
839841
eps: float = global_eps
840-
) -> Array:
842+
) -> Tuple[Array, Array]:
841843
"""
842844
Compute the inertial-frame jacobian and its derivative with respect to the strain vector
843845
at a given point s using explicit expression in SE(2).
@@ -882,7 +884,7 @@ def J_Jd_explicit_global(
882884
def J_i(
883885
tuple_J_prev: Array,
884886
i: int
885-
) -> Array:
887+
) -> Tuple[Tuple[Array, Array], Array]:
886888
J_prev_Lprev, _ = tuple_J_prev
887889

888890
start_index = 3 * i
@@ -915,7 +917,8 @@ def J_i(
915917
_, J_array = lax.scan(
916918
f = J_i,
917919
init = tuple_J_0,
918-
xs = jnp.arange(1, num_segments))
920+
xs = jnp.arange(1, num_segments)
921+
)
919922

920923
# Add the initial condition to the Jacobian array
921924
J_array = jnp.concatenate([J_0_s[jnp.newaxis, ...], J_array], axis=0)
@@ -1705,7 +1708,7 @@ def operational_space_dynamical_matrices_fn(
17051708

17061709
return Lambda, mu, J, J_d, JB_pinv
17071710

1708-
auxiliary_fns = {
1711+
auxiliary_fns: Dict[str, Callable[..., Any]] = {
17091712
"apply_eps_to_bend_strains": apply_eps_to_bend_strains,
17101713
"classify_segment": classify_segment,
17111714
"stiffness_fn": stiffness_fn,
@@ -1717,4 +1720,4 @@ def operational_space_dynamical_matrices_fn(
17171720
"operational_space_dynamical_matrices_fn": operational_space_dynamical_matrices_fn,
17181721
}
17191722

1720-
return B_xi, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns
1723+
return B_xi, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns

tests/test_fwd_kine_eps_planar_pcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from matplotlib import rc
66
rc('animation', html='html5')
77

8-
from jsrm.systems import planar_pcs, planar_pcs_num
8+
from jsrm.systems import planar_pcs_num, planar_pcs_sym
99
from pathlib import Path
1010
import jsrm
1111
from tqdm import tqdm
@@ -33,7 +33,7 @@
3333
def get_fwd_kine_fn(jacobian_type):
3434
if jacobian_type == "symbolic":
3535
sym_exp_filepath = Path(jsrm.__file__).parent / "symbolic_expressions" / f"planar_pcs_ns-{num_segments}.dill"
36-
_, fwd, _, _ = planar_pcs.factory(sym_exp_filepath, strain_selector)
36+
_, fwd, _, _ = planar_pcs_sym.factory(sym_exp_filepath, strain_selector)
3737
else:
3838
_, fwd, _, _ = planar_pcs_num.factory(
3939
num_segments, strain_selector,

tests/test_jacobian_eps_planar_pcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from matplotlib import rc
66
rc('animation', html='html5')
77

8-
from jsrm.systems import planar_pcs, planar_pcs_num
8+
from jsrm.systems import planar_pcs_num, planar_pcs_sym
99
from pathlib import Path
1010
import jsrm
1111
from tqdm import tqdm
@@ -33,7 +33,7 @@
3333
def get_jacobian_fn(jacobian_type):
3434
if jacobian_type == "symbolic":
3535
sym_exp_filepath = Path(jsrm.__file__).parent / "symbolic_expressions" / f"planar_pcs_ns-{num_segments}.dill"
36-
_, _, _, aux = planar_pcs.factory(sym_exp_filepath, strain_selector)
36+
_, _, _, aux = planar_pcs_sym.factory(sym_exp_filepath, strain_selector)
3737
else:
3838
_, _, _, aux = planar_pcs_num.factory(
3939
num_segments, strain_selector,

0 commit comments

Comments
 (0)