Skip to content

Commit 8c58036

Browse files
authored
integrate: improved solve_ivp annotations (#765)
2 parents e36d9dd + e8d8153 commit 8c58036

File tree

3 files changed

+220
-70
lines changed

3 files changed

+220
-70
lines changed
Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from collections.abc import Callable, Sequence
2-
from typing import Final, Literal, TypeAlias, TypeVar, overload
1+
from collections.abc import Callable
2+
from typing import Any, Final, Generic, Literal, TypeAlias, overload
3+
from typing_extensions import TypeVar
34

45
import numpy as np
56
import optype as op
@@ -9,13 +10,10 @@ import optype.numpy.compat as npc
910
from .base import DenseOutput
1011
from scipy.sparse import csc_matrix
1112

12-
_SCT = TypeVar("_SCT", bound=np.generic)
13+
_FloatingT = TypeVar("_FloatingT", bound=npc.floating)
1314
_ToFloatT = TypeVar("_ToFloatT", bound=onp.ToFloat)
15+
_InterpT_co = TypeVar("_InterpT_co", bound=DenseOutput, default=DenseOutput, covariant=True)
1416

15-
_Side: TypeAlias = Literal["left", "right"]
16-
_Interpolants: TypeAlias = Sequence[DenseOutput]
17-
18-
_To1D: TypeAlias = Sequence[_SCT] | onp.CanArrayND[_SCT]
1917
_ToFloat64: TypeAlias = np.float16 | np.float32 | np.float64 | npc.integer | np.bool_
2018

2119
###
@@ -28,43 +26,41 @@ NUM_JAC_MIN_FACTOR: Final[float] = ...
2826
NUM_JAC_FACTOR_INCREASE: Final[float] = 10
2927
NUM_JAC_FACTOR_DECREASE: Final[float] = 0.1
3028

31-
class OdeSolution:
29+
class OdeSolution(Generic[_InterpT_co]):
30+
interpolants: list[_InterpT_co]
3231
ts: onp.Array1D[np.float64]
3332
ts_sorted: onp.Array1D[np.float64]
3433
t_min: np.float64
3534
t_max: np.float64
3635
ascending: bool
37-
side: _Side
36+
side: Literal["left", "right"]
3837
n_segments: int
39-
interpolants: _Interpolants
4038

41-
def __init__(self, /, ts: onp.ToFloat1D, interpolants: _Interpolants, alt_segment: op.CanBool = False) -> None: ...
39+
def __init__(self, /, ts: onp.ToFloat1D, interpolants: list[_InterpT_co], alt_segment: op.CanBool = False) -> None: ...
40+
41+
#
4242
@overload
4343
def __call__(self, /, t: float | _ToFloat64) -> onp.Array1D[np.float64]: ...
4444
@overload
45-
def __call__(self, /, t: np.complex64 | np.complex128) -> onp.Array1D[np.complex128]: ...
45+
def __call__(self, /, t: op.JustComplex | np.complex128 | np.complex64) -> onp.Array1D[np.complex128]: ...
4646
@overload
4747
def __call__(self, /, t: np.longdouble) -> onp.Array1D[np.longdouble]: ...
4848
@overload
4949
def __call__(self, /, t: np.clongdouble) -> onp.Array1D[np.clongdouble]: ...
5050
@overload
51-
def __call__(self, /, t: complex) -> onp.Array1D[np.float64 | np.complex128]: ...
52-
@overload
53-
def __call__(self, /, t: Sequence[float | _ToFloat64] | onp.CanArrayND[_ToFloat64]) -> onp.Array2D[np.float64]: ...
54-
@overload
55-
def __call__(self, /, t: _To1D[np.complex64 | np.complex128]) -> onp.Array2D[np.complex128]: ...
51+
def __call__(self, /, t: onp.ToArray1D[float, _ToFloat64]) -> onp.Array2D[np.float64]: ...
5652
@overload
57-
def __call__(self, /, t: _To1D[np.clongdouble]) -> onp.Array2D[np.clongdouble]: ...
53+
def __call__(self, /, t: onp.ToArray1D[op.JustComplex, np.complex128 | np.complex64]) -> onp.Array2D[np.complex128]: ...
5854
@overload
59-
def __call__(self, /, t: Sequence[complex]) -> onp.Array2D[np.float64 | np.complex128]: ...
55+
def __call__(self, /, t: onp.CanArrayND[np.clongdouble]) -> onp.Array2D[np.clongdouble]: ...
6056

61-
def validate_first_step(first_step: _ToFloatT, t0: onp.ToFloat, t_bound: onp.ToFloat) -> _ToFloatT: ...
62-
def validate_max_step(max_step: _ToFloatT) -> _ToFloatT: ...
63-
def warn_extraneous(extraneous: dict[str, object]) -> None: ...
57+
def validate_first_step(first_step: _ToFloatT, t0: onp.ToFloat, t_bound: onp.ToFloat) -> _ToFloatT: ... # undocumented
58+
def validate_max_step(max_step: _ToFloatT) -> _ToFloatT: ... # undocumented
59+
def warn_extraneous(extraneous: dict[str, Any]) -> None: ... # undocumented
6460
def validate_tol(
65-
rtol: onp.ArrayND[npc.floating], atol: onp.ArrayND[npc.floating], n: int
66-
) -> tuple[onp.Array1D[npc.floating], onp.Array1D[npc.floating]]: ...
67-
def norm(x: onp.ToFloatND) -> npc.floating: ...
61+
rtol: onp.ArrayND[_FloatingT], atol: onp.ArrayND[_FloatingT], n: int
62+
) -> tuple[onp.Array1D[_FloatingT], onp.Array1D[_FloatingT]]: ... # undocumented
63+
def norm(x: onp.ToFloatND) -> npc.floating: ... # undocumented
6864
def select_initial_step(
6965
fun: Callable[[np.float64, onp.Array1D[np.float64]], onp.Array1D[np.float64]],
7066
t0: float | np.float64,
@@ -76,7 +72,7 @@ def select_initial_step(
7672
order: float | np.float64,
7773
rtol: float | np.float64,
7874
atol: float | np.float64,
79-
) -> float | np.float64: ...
75+
) -> float: ... # undocumented
8076
def num_jac(
8177
fun: Callable[[np.float64, onp.Array1D[np.float64]], onp.Array1D[np.float64]],
8278
t: float | np.float64,
@@ -85,4 +81,4 @@ def num_jac(
8581
threshold: float | np.float64,
8682
factor: onp.ArrayND[np.float64] | None,
8783
sparsity: tuple[csc_matrix, onp.ArrayND[np.intp]] | None = None,
88-
) -> tuple[onp.Array2D[np.float64] | csc_matrix, onp.Array1D[np.float64]]: ...
84+
) -> tuple[onp.Array2D[np.float64] | csc_matrix[np.float64], onp.Array1D[np.float64]]: ... # undocumented

scipy-stubs/integrate/_ivp/ivp.pyi

Lines changed: 131 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable, Sequence
2-
from typing import Concatenate, Final, Generic, Literal, TypeAlias, overload, type_check_only
2+
from typing import Any, Final, Generic, Literal, TypeAlias, TypeVarTuple, overload, type_check_only
33
from typing_extensions import TypeVar, TypedDict, Unpack
44

55
import numpy as np
@@ -10,58 +10,63 @@ from .base import DenseOutput, OdeSolver
1010
from .common import OdeSolution
1111
from scipy._lib._util import _RichResult
1212
from scipy.sparse import sparray, spmatrix
13+
from scipy.sparse._base import _spbase
1314

14-
_SCT_cf = TypeVar("_SCT_cf", bound=npc.inexact, default=np.float64 | np.complex128)
15+
_Ts = TypeVarTuple("_Ts")
16+
_ScalarT = TypeVar("_ScalarT", bound=npc.number | np.bool)
17+
_Inexact64T = TypeVar("_Inexact64T", bound=np.float64 | np.complex128)
18+
_Inexact64T_co = TypeVar("_Inexact64T_co", bound=np.float64 | np.complex128, default=np.float64 | np.complex128, covariant=True)
1519

16-
_FuncSol: TypeAlias = Callable[[float], onp.ArrayND[_SCT_cf]]
17-
_FuncEvent: TypeAlias = Callable[[float, onp.ArrayND[_SCT_cf]], float]
18-
_Events: TypeAlias = Sequence[_FuncEvent[_SCT_cf]]
20+
_FuncSol: TypeAlias = Callable[[np.float64], onp.ArrayND[_Inexact64T]]
21+
_FuncEvent: TypeAlias = Callable[[np.float64, onp.ArrayND[_Inexact64T], *_Ts], float]
22+
_Events: TypeAlias = Sequence[_FuncEvent[_Inexact64T, *_Ts]] | _FuncEvent[_Inexact64T, *_Ts]
1923

20-
_Int1D: TypeAlias = onp.Array1D[np.intp]
24+
_Int1D: TypeAlias = onp.Array1D[np.int_]
2125
_Float1D: TypeAlias = onp.Array1D[np.float64]
26+
_Float2D: TypeAlias = onp.Array2D[np.float64]
27+
_Complex1D: TypeAlias = onp.Array1D[np.complex128]
28+
_Complex2D: TypeAlias = onp.Array2D[np.complex128]
2229

23-
_ToJac: TypeAlias = onp.ToComplex2D | spmatrix | sparray
30+
_Sparse2D: TypeAlias = _spbase[_ScalarT, tuple[int, int]] | sparray[_ScalarT, tuple[int, int]] | spmatrix[_ScalarT]
31+
_ToJac: TypeAlias = onp.ToArray2D[complex, npc.inexact] | _Sparse2D[npc.inexact]
2432

25-
_IVPMethod: TypeAlias = Literal["RK23", "RK45", "DOP853", "Radau", "BDF", "LSODA"]
33+
_IVPMethod: TypeAlias = Literal["RK23", "RK45", "DOP853", "Radau", "BDF", "LSODA"] | type[OdeSolver]
2634

2735
@type_check_only
28-
class _SolverOptions(TypedDict, Generic[_SCT_cf], total=False):
29-
first_step: onp.ToFloat | None
30-
max_step: onp.ToFloat
31-
rtol: onp.ToFloat | onp.ToFloat1D
32-
atol: onp.ToFloat | onp.ToFloat1D
33-
jac: _ToJac | Callable[[float, onp.Array1D[np.float64]], _ToJac] | None
34-
jac_sparsity: onp.ToFloat2D | spmatrix | sparray | None
35-
lband: onp.ToInt | None
36-
uband: onp.ToInt | None
37-
min_step: onp.ToFloat
36+
class _SolverOptions(TypedDict, total=False):
37+
first_step: float | None
38+
max_step: float
39+
rtol: float | onp.ToFloat1D
40+
atol: float | onp.ToFloat1D
41+
jac: _ToJac | Callable[[np.float64, onp.Array1D], _ToJac] | None
42+
jac_sparsity: onp.ToFloat2D | _Sparse2D[npc.floating] | None
43+
lband: int | None
44+
uband: int | None
45+
min_step: float
3846

3947
###
4048

4149
METHODS: Final[dict[str, type]] = ...
4250
MESSAGES: Final[dict[int, str]] = ...
4351

44-
class OdeResult(
45-
_RichResult[int | str | onp.ArrayND[np.float64 | _SCT_cf] | list[onp.ArrayND[np.float64 | _SCT_cf]] | OdeSolution | None],
46-
Generic[_SCT_cf],
47-
):
52+
class OdeResult(_RichResult[Any], Generic[_Inexact64T_co]):
4853
t: _Float1D
49-
y: onp.Array2D[_SCT_cf]
54+
y: onp.Array2D[_Inexact64T_co]
5055
sol: OdeSolution | None
5156
t_events: list[_Float1D] | None
52-
y_events: list[onp.ArrayND[_SCT_cf]] | None
57+
y_events: list[onp.ArrayND[_Inexact64T_co]] | None
5358
nfev: int
5459
njev: int
5560
nlu: int
5661
status: Literal[-1, 0, 1]
5762
message: str
5863
success: bool
5964

60-
def prepare_events(events: _FuncEvent[_SCT_cf] | _Events[_SCT_cf]) -> tuple[_Events[_SCT_cf], _Float1D, _Float1D]: ...
61-
def solve_event_equation(event: _FuncEvent[_SCT_cf], sol: _FuncSol[_SCT_cf], t_old: float, t: float) -> float: ...
65+
def prepare_events(events: _Events[_Inexact64T]) -> tuple[_Events[_Inexact64T], _Float1D, _Float1D]: ...
66+
def solve_event_equation(event: _FuncEvent[_Inexact64T], sol: _FuncSol[_Inexact64T], t_old: float, t: float) -> float: ...
6267
def handle_events(
6368
sol: DenseOutput,
64-
events: Sequence[_FuncEvent[_SCT_cf]],
69+
events: Sequence[_FuncEvent[_Inexact64T]],
6570
active_events: onp.ArrayND[np.intp],
6671
event_count: onp.ArrayND[np.intp | np.float64],
6772
max_events: onp.ArrayND[np.intp | np.float64],
@@ -71,30 +76,113 @@ def handle_events(
7176
def find_active_events(g: onp.ToFloat1D, g_new: onp.ToFloat1D, direction: onp.ArrayND[np.float64]) -> _Int1D: ...
7277

7378
#
74-
@overload
79+
@overload # float, vectorized=False (default), args=None (default)
7580
def solve_ivp(
76-
fun: Callable[Concatenate[float, onp.Array1D[_SCT_cf], ...], onp.ArrayND[_SCT_cf]],
77-
t_span: Sequence[onp.ToFloat],
78-
y0: onp.ToArray1D,
79-
method: _IVPMethod | type[OdeSolver] = "RK45",
81+
fun: Callable[[np.float64, _Float1D], onp.ToFloat1D | float],
82+
t_span: Sequence[float],
83+
y0: onp.ToFloat1D,
84+
method: _IVPMethod = "RK45",
8085
t_eval: onp.ToFloat1D | None = None,
8186
dense_output: bool = False,
82-
events: _Events[_SCT_cf] | None = None,
87+
events: _Events[np.float64] | None = None,
8388
vectorized: onp.ToFalse = False,
84-
args: tuple[object, ...] | None = None,
89+
args: None = None,
8590
**options: Unpack[_SolverOptions],
86-
) -> OdeResult[_SCT_cf]: ...
87-
@overload
91+
) -> OdeResult[np.float64]: ...
92+
@overload # float, vectorized=False (default), args=<given>
8893
def solve_ivp(
89-
fun: Callable[Concatenate[_Float1D, onp.Array2D[_SCT_cf], ...], onp.ArrayND[_SCT_cf]],
90-
t_span: Sequence[onp.ToFloat],
91-
y0: onp.ToArray1D,
92-
method: _IVPMethod | type[OdeSolver] = "RK45",
94+
fun: Callable[[np.float64, _Float1D, *_Ts], onp.ToFloat1D | float],
95+
t_span: Sequence[float],
96+
y0: onp.ToFloat1D,
97+
method: _IVPMethod = "RK45",
9398
t_eval: onp.ToFloat1D | None = None,
9499
dense_output: bool = False,
95-
events: _Events[_SCT_cf] | None = None,
100+
events: _Events[np.float64] | None = None,
101+
vectorized: onp.ToFalse = False,
102+
*,
103+
args: tuple[*_Ts],
104+
**options: Unpack[_SolverOptions],
105+
) -> OdeResult[np.float64]: ...
106+
@overload # float, vectorized=True, args=None (default)
107+
def solve_ivp(
108+
fun: Callable[[_Float1D, _Float2D], onp.ToFloat2D],
109+
t_span: Sequence[float],
110+
y0: onp.ToFloat1D,
111+
method: _IVPMethod = "RK45",
112+
t_eval: onp.ToFloat1D | None = None,
113+
dense_output: bool = False,
114+
events: _Events[np.float64] | None = None,
115+
*,
116+
vectorized: onp.ToTrue,
117+
args: None = None,
118+
**options: Unpack[_SolverOptions],
119+
) -> OdeResult[np.float64]: ...
120+
@overload # float, vectorized=True, args=<given>
121+
def solve_ivp(
122+
fun: Callable[[_Float1D, _Float2D, *_Ts], onp.ToFloat2D],
123+
t_span: Sequence[float],
124+
y0: onp.ToFloat1D,
125+
method: _IVPMethod = "RK45",
126+
t_eval: onp.ToFloat1D | None = None,
127+
dense_output: bool = False,
128+
events: _Events[np.float64] | None = None,
129+
*,
130+
vectorized: onp.ToTrue,
131+
args: tuple[*_Ts],
132+
**options: Unpack[_SolverOptions],
133+
) -> OdeResult[np.float64]: ...
134+
@overload # complex, vectorized=False (default), args=None (default)
135+
def solve_ivp(
136+
fun: Callable[[np.float64, _Complex1D], onp.ToComplex1D | complex],
137+
t_span: Sequence[float],
138+
y0: onp.ToComplex1D,
139+
method: _IVPMethod = "RK45",
140+
t_eval: onp.ToFloat1D | None = None,
141+
dense_output: bool = False,
142+
events: _Events[np.complex128] | None = None,
143+
vectorized: onp.ToFalse = False,
144+
args: None = None,
145+
**options: Unpack[_SolverOptions],
146+
) -> OdeResult[np.complex128]: ...
147+
@overload # complex, vectorized=False (default), args=<given>
148+
def solve_ivp(
149+
fun: Callable[[np.float64, _Complex1D, *_Ts], onp.ToComplex1D | complex],
150+
t_span: Sequence[float],
151+
y0: onp.ToComplex1D,
152+
method: _IVPMethod = "RK45",
153+
t_eval: onp.ToFloat1D | None = None,
154+
dense_output: bool = False,
155+
events: _Events[np.complex128] | None = None,
156+
vectorized: onp.ToFalse = False,
157+
*,
158+
args: tuple[*_Ts],
159+
**options: Unpack[_SolverOptions],
160+
) -> OdeResult[np.complex128]: ...
161+
@overload # complex, vectorized=True, args=None (default)
162+
def solve_ivp(
163+
fun: Callable[[_Float1D, _Complex2D], onp.ToComplex2D],
164+
t_span: Sequence[float],
165+
y0: onp.ToComplex1D,
166+
method: _IVPMethod = "RK45",
167+
t_eval: onp.ToFloat1D | None = None,
168+
dense_output: bool = False,
169+
events: _Events[np.complex128] | None = None,
170+
*,
171+
vectorized: onp.ToTrue,
172+
args: None = None,
173+
**options: Unpack[_SolverOptions],
174+
) -> OdeResult[np.complex128]: ...
175+
@overload # complex, vectorized=True, args=<given>
176+
def solve_ivp(
177+
fun: Callable[[_Float1D, _Complex2D, *_Ts], onp.ToComplex2D],
178+
t_span: Sequence[float],
179+
y0: onp.ToComplex1D,
180+
method: _IVPMethod = "RK45",
181+
t_eval: onp.ToFloat1D | None = None,
182+
dense_output: bool = False,
183+
events: _Events[np.complex128] | None = None,
96184
*,
97185
vectorized: onp.ToTrue,
98-
args: tuple[object, ...] | None = None,
186+
args: tuple[*_Ts],
99187
**options: Unpack[_SolverOptions],
100-
) -> OdeResult[_SCT_cf]: ...
188+
) -> OdeResult[np.complex128]: ...

tests/integrate/test_solve_ivp.pyi

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import TypeAlias, assert_type, type_check_only
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
from scipy.integrate import solve_ivp
7+
8+
_VecF64: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float64]]
9+
_MatF64: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.float64]]
10+
_ArrF64: TypeAlias = np.ndarray[tuple[int, ...], np.dtype[np.float64]]
11+
_VecC128: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex128]]
12+
_MatC128: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.complex128]]
13+
_ArrC128: TypeAlias = np.ndarray[tuple[int, ...], np.dtype[np.complex128]]
14+
15+
list_float: list[float] = ...
16+
list_complex: list[complex] = ...
17+
18+
vec_f64: _VecF64 = ...
19+
arr_f64: _ArrF64 = ...
20+
21+
vec_c128: _VecC128 = ...
22+
arr_c128: _ArrC128 = ...
23+
24+
# NOTE: these examples are based on the `solve_ivp` docstring, and use common (suboptimal) type annotation patterns.
25+
###
26+
27+
@type_check_only
28+
def exponential_decay(t: float, y: _ArrF64) -> _ArrF64: ...
29+
30+
assert_type(solve_ivp(exponential_decay, list_float, list_float).t, _VecF64)
31+
assert_type(solve_ivp(exponential_decay, list_float, list_float).y, _MatF64)
32+
assert_type(solve_ivp(exponential_decay, list_float, list_float, args=()).y, _MatF64)
33+
34+
###
35+
36+
@type_check_only
37+
def upward_cannon(t: np.float64, y: _VecF64) -> list[float]: ...
38+
@type_check_only
39+
def hit_ground(t: np.float64, y: _VecF64) -> np.float64: ...
40+
41+
assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground).y, _MatF64)
42+
assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground, args=()).y, _MatF64)
43+
assert_type(solve_ivp(upward_cannon, list_float, list_float, events=hit_ground, dense_output=True).y, _MatF64)
44+
45+
###
46+
47+
@type_check_only
48+
def lotkavolterra(
49+
t: float, z: np.ndarray[tuple[int, ...], np.dtype[np.float64]], a: float, b: float, c: float, d: float
50+
) -> _VecF64: ...
51+
52+
assert_type(solve_ivp(lotkavolterra, list_float, list_float, args=(1.5, 1, 3, 1)).y, _MatF64)
53+
assert_type(solve_ivp(lotkavolterra, list_float, list_float, args=(1.5, 1, 3, 1), dense_output=True).y, _MatF64)
54+
55+
###
56+
57+
@type_check_only
58+
def deriv_vec(t: float, y: npt.NDArray[np.float64 | np.complex128]) -> npt.NDArray[np.float64 | np.complex128]: ...
59+
60+
assert_type(solve_ivp(deriv_vec, list_float, list_complex).y, _MatC128)
61+
assert_type(solve_ivp(deriv_vec, list_float, vec_c128).y, _MatC128)
62+
assert_type(solve_ivp(deriv_vec, list_float, arr_c128).y, _MatC128)
63+
64+
assert_type(solve_ivp(deriv_vec, list_float, arr_c128, t_eval=list_float).y, _MatC128)
65+
assert_type(solve_ivp(deriv_vec, list_float, list_complex, t_eval=vec_f64).y, _MatC128)
66+
assert_type(solve_ivp(deriv_vec, list_float, vec_c128, t_eval=arr_f64).y, _MatC128)

0 commit comments

Comments
 (0)