Skip to content

Commit e8d8153

Browse files
committed
integrate: add type-tests for solve_ivp
1 parent 297f6e6 commit e8d8153

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

tests/integrate/test_solve_ivp.pyi

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import TypeAlias, assert_type, type_check_only
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
from scipy.integrate import solve_ivp
7+
8+
_VecF64: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float64]]
9+
_MatF64: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.float64]]
10+
_ArrF64: TypeAlias = np.ndarray[tuple[int, ...], np.dtype[np.float64]]
11+
_VecC128: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex128]]
12+
_MatC128: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.complex128]]
13+
_ArrC128: TypeAlias = np.ndarray[tuple[int, ...], np.dtype[np.complex128]]
14+
15+
list_float: list[float] = ...
16+
list_complex: list[complex] = ...
17+
18+
vec_f64: _VecF64 = ...
19+
arr_f64: _ArrF64 = ...
20+
21+
vec_c128: _VecC128 = ...
22+
arr_c128: _ArrC128 = ...
23+
24+
# NOTE: these examples are based on the `solve_ivp` docstring, and use common (suboptimal) type annotation patterns.
25+
###
26+
27+
@type_check_only
28+
def exponential_decay(t: float, y: _ArrF64) -> _ArrF64: ...
29+
30+
assert_type(solve_ivp(exponential_decay, list_float, list_float).t, _VecF64)
31+
assert_type(solve_ivp(exponential_decay, list_float, list_float).y, _MatF64)
32+
assert_type(solve_ivp(exponential_decay, list_float, list_float, args=()).y, _MatF64)
33+
34+
###
35+
36+
@type_check_only
37+
def upward_cannon(t: np.float64, y: _VecF64) -> list[float]: ...
38+
@type_check_only
39+
def hit_ground(t: np.float64, y: _VecF64) -> np.float64: ...
40+
41+
assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground).y, _MatF64)
42+
assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground, args=()).y, _MatF64)
43+
assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground, dense_output=True).y, _MatF64)
44+
45+
###
46+
47+
@type_check_only
48+
def lotkavolterra(
49+
t: float, z: np.ndarray[tuple[int, ...], np.dtype[np.float64]], a: float, b: float, c: float, d: float
50+
) -> _VecF64: ...
51+
52+
assert_type(solve_ivp(lotkavolterra, list_float, list_float, args=(1.5, 1, 3, 1)).y, _MatF64)
53+
assert_type(solve_ivp(lotkavolterra, list_float, list_float, args=(1.5, 1, 3, 1), dense_output=True).y, _MatF64)
54+
55+
###
56+
57+
@type_check_only
58+
def deriv_vec(t: float, y: npt.NDArray[np.float64 | np.complex128]) -> npt.NDArray[np.float64 | np.complex128]: ...
59+
60+
assert_type(solve_ivp(deriv_vec, list_float, list_complex).y, _MatC128)
61+
assert_type(solve_ivp(deriv_vec, list_float, vec_c128).y, _MatC128)
62+
assert_type(solve_ivp(deriv_vec, list_float, arr_c128).y, _MatC128)
63+
64+
assert_type(solve_ivp(deriv_vec, list_float, arr_c128, t_eval=list_float).y, _MatC128)
65+
assert_type(solve_ivp(deriv_vec, list_float, list_complex, t_eval=vec_f64).y, _MatC128)
66+
assert_type(solve_ivp(deriv_vec, list_float, vec_c128, t_eval=arr_f64).y, _MatC128)

0 commit comments

Comments
 (0)