|
| 1 | +from typing import TypeAlias, assert_type, type_check_only |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import numpy.typing as npt |
| 5 | + |
| 6 | +from scipy.integrate import solve_ivp |
| 7 | + |
| 8 | +_VecF64: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float64]] |
| 9 | +_MatF64: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.float64]] |
| 10 | +_ArrF64: TypeAlias = np.ndarray[tuple[int, ...], np.dtype[np.float64]] |
| 11 | +_VecC128: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex128]] |
| 12 | +_MatC128: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.complex128]] |
| 13 | +_ArrC128: TypeAlias = np.ndarray[tuple[int, ...], np.dtype[np.complex128]] |
| 14 | + |
| 15 | +list_float: list[float] = ... |
| 16 | +list_complex: list[complex] = ... |
| 17 | + |
| 18 | +vec_f64: _VecF64 = ... |
| 19 | +arr_f64: _ArrF64 = ... |
| 20 | + |
| 21 | +vec_c128: _VecC128 = ... |
| 22 | +arr_c128: _ArrC128 = ... |
| 23 | + |
| 24 | +# NOTE: these examples are based on the `solve_ivp` docstring, and use common (suboptimal) type annotation patterns. |
| 25 | +### |
| 26 | + |
| 27 | +@type_check_only |
| 28 | +def exponential_decay(t: float, y: _ArrF64) -> _ArrF64: ... |
| 29 | + |
| 30 | +assert_type(solve_ivp(exponential_decay, list_float, list_float).t, _VecF64) |
| 31 | +assert_type(solve_ivp(exponential_decay, list_float, list_float).y, _MatF64) |
| 32 | +assert_type(solve_ivp(exponential_decay, list_float, list_float, args=()).y, _MatF64) |
| 33 | + |
| 34 | +### |
| 35 | + |
| 36 | +@type_check_only |
| 37 | +def upward_cannon(t: np.float64, y: _VecF64) -> list[float]: ... |
| 38 | +@type_check_only |
| 39 | +def hit_ground(t: np.float64, y: _VecF64) -> np.float64: ... |
| 40 | + |
| 41 | +assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground).y, _MatF64) |
| 42 | +assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground, args=()).y, _MatF64) |
| 43 | +assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground, dense_output=True).y, _MatF64) |
| 44 | + |
| 45 | +### |
| 46 | + |
| 47 | +@type_check_only |
| 48 | +def lotkavolterra( |
| 49 | + t: float, z: np.ndarray[tuple[int, ...], np.dtype[np.float64]], a: float, b: float, c: float, d: float |
| 50 | +) -> _VecF64: ... |
| 51 | + |
| 52 | +assert_type(solve_ivp(lotkavolterra, list_float, list_float, args=(1.5, 1, 3, 1)).y, _MatF64) |
| 53 | +assert_type(solve_ivp(lotkavolterra, list_float, list_float, args=(1.5, 1, 3, 1), dense_output=True).y, _MatF64) |
| 54 | + |
| 55 | +### |
| 56 | + |
| 57 | +@type_check_only |
| 58 | +def deriv_vec(t: float, y: npt.NDArray[np.float64 | np.complex128]) -> npt.NDArray[np.float64 | np.complex128]: ... |
| 59 | + |
| 60 | +assert_type(solve_ivp(deriv_vec, list_float, list_complex).y, _MatC128) |
| 61 | +assert_type(solve_ivp(deriv_vec, list_float, vec_c128).y, _MatC128) |
| 62 | +assert_type(solve_ivp(deriv_vec, list_float, arr_c128).y, _MatC128) |
| 63 | + |
| 64 | +assert_type(solve_ivp(deriv_vec, list_float, arr_c128, t_eval=list_float).y, _MatC128) |
| 65 | +assert_type(solve_ivp(deriv_vec, list_float, list_complex, t_eval=vec_f64).y, _MatC128) |
| 66 | +assert_type(solve_ivp(deriv_vec, list_float, vec_c128, t_eval=arr_f64).y, _MatC128) |
0 commit comments