Skip to content

Commit 297f6e6

Browse files
committed
integrate: improved solve_ivp annotations
1 parent a12d226 commit 297f6e6

File tree

1 file changed

+131
-43
lines changed

1 file changed

+131
-43
lines changed

scipy-stubs/integrate/_ivp/ivp.pyi

Lines changed: 131 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable, Sequence
2-
from typing import Concatenate, Final, Generic, Literal, TypeAlias, overload, type_check_only
2+
from typing import Any, Final, Generic, Literal, TypeAlias, TypeVarTuple, overload, type_check_only
33
from typing_extensions import TypeVar, TypedDict, Unpack
44

55
import numpy as np
@@ -10,58 +10,63 @@ from .base import DenseOutput, OdeSolver
1010
from .common import OdeSolution
1111
from scipy._lib._util import _RichResult
1212
from scipy.sparse import sparray, spmatrix
13+
from scipy.sparse._base import _spbase
1314

14-
_SCT_cf = TypeVar("_SCT_cf", bound=npc.inexact, default=np.float64 | np.complex128)
15+
_Ts = TypeVarTuple("_Ts")
16+
_ScalarT = TypeVar("_ScalarT", bound=npc.number | np.bool)
17+
_Inexact64T = TypeVar("_Inexact64T", bound=np.float64 | np.complex128)
18+
_Inexact64T_co = TypeVar("_Inexact64T_co", bound=np.float64 | np.complex128, default=np.float64 | np.complex128, covariant=True)
1519

16-
_FuncSol: TypeAlias = Callable[[float], onp.ArrayND[_SCT_cf]]
17-
_FuncEvent: TypeAlias = Callable[[float, onp.ArrayND[_SCT_cf]], float]
18-
_Events: TypeAlias = Sequence[_FuncEvent[_SCT_cf]]
20+
_FuncSol: TypeAlias = Callable[[np.float64], onp.ArrayND[_Inexact64T]]
21+
_FuncEvent: TypeAlias = Callable[[np.float64, onp.ArrayND[_Inexact64T], *_Ts], float]
22+
_Events: TypeAlias = Sequence[_FuncEvent[_Inexact64T, *_Ts]] | _FuncEvent[_Inexact64T, *_Ts]
1923

20-
_Int1D: TypeAlias = onp.Array1D[np.intp]
24+
_Int1D: TypeAlias = onp.Array1D[np.int_]
2125
_Float1D: TypeAlias = onp.Array1D[np.float64]
26+
_Float2D: TypeAlias = onp.Array2D[np.float64]
27+
_Complex1D: TypeAlias = onp.Array1D[np.complex128]
28+
_Complex2D: TypeAlias = onp.Array2D[np.complex128]
2229

23-
_ToJac: TypeAlias = onp.ToComplex2D | spmatrix | sparray
30+
_Sparse2D: TypeAlias = _spbase[_ScalarT, tuple[int, int]] | sparray[_ScalarT, tuple[int, int]] | spmatrix[_ScalarT]
31+
_ToJac: TypeAlias = onp.ToArray2D[complex, npc.inexact] | _Sparse2D[npc.inexact]
2432

25-
_IVPMethod: TypeAlias = Literal["RK23", "RK45", "DOP853", "Radau", "BDF", "LSODA"]
33+
_IVPMethod: TypeAlias = Literal["RK23", "RK45", "DOP853", "Radau", "BDF", "LSODA"] | type[OdeSolver]
2634

2735
@type_check_only
28-
class _SolverOptions(TypedDict, Generic[_SCT_cf], total=False):
29-
first_step: onp.ToFloat | None
30-
max_step: onp.ToFloat
31-
rtol: onp.ToFloat | onp.ToFloat1D
32-
atol: onp.ToFloat | onp.ToFloat1D
33-
jac: _ToJac | Callable[[float, onp.Array1D[np.float64]], _ToJac] | None
34-
jac_sparsity: onp.ToFloat2D | spmatrix | sparray | None
35-
lband: onp.ToInt | None
36-
uband: onp.ToInt | None
37-
min_step: onp.ToFloat
36+
class _SolverOptions(TypedDict, total=False):
37+
first_step: float | None
38+
max_step: float
39+
rtol: float | onp.ToFloat1D
40+
atol: float | onp.ToFloat1D
41+
jac: _ToJac | Callable[[np.float64, onp.Array1D], _ToJac] | None
42+
jac_sparsity: onp.ToFloat2D | _Sparse2D[npc.floating] | None
43+
lband: int | None
44+
uband: int | None
45+
min_step: float
3846

3947
###
4048

4149
METHODS: Final[dict[str, type]] = ...
4250
MESSAGES: Final[dict[int, str]] = ...
4351

44-
class OdeResult(
45-
_RichResult[int | str | onp.ArrayND[np.float64 | _SCT_cf] | list[onp.ArrayND[np.float64 | _SCT_cf]] | OdeSolution | None],
46-
Generic[_SCT_cf],
47-
):
52+
class OdeResult(_RichResult[Any], Generic[_Inexact64T_co]):
4853
t: _Float1D
49-
y: onp.Array2D[_SCT_cf]
54+
y: onp.Array2D[_Inexact64T_co]
5055
sol: OdeSolution | None
5156
t_events: list[_Float1D] | None
52-
y_events: list[onp.ArrayND[_SCT_cf]] | None
57+
y_events: list[onp.ArrayND[_Inexact64T_co]] | None
5358
nfev: int
5459
njev: int
5560
nlu: int
5661
status: Literal[-1, 0, 1]
5762
message: str
5863
success: bool
5964

60-
def prepare_events(events: _FuncEvent[_SCT_cf] | _Events[_SCT_cf]) -> tuple[_Events[_SCT_cf], _Float1D, _Float1D]: ...
61-
def solve_event_equation(event: _FuncEvent[_SCT_cf], sol: _FuncSol[_SCT_cf], t_old: float, t: float) -> float: ...
65+
def prepare_events(events: _Events[_Inexact64T]) -> tuple[_Events[_Inexact64T], _Float1D, _Float1D]: ...
66+
def solve_event_equation(event: _FuncEvent[_Inexact64T], sol: _FuncSol[_Inexact64T], t_old: float, t: float) -> float: ...
6267
def handle_events(
6368
sol: DenseOutput,
64-
events: Sequence[_FuncEvent[_SCT_cf]],
69+
events: Sequence[_FuncEvent[_Inexact64T]],
6570
active_events: onp.ArrayND[np.intp],
6671
event_count: onp.ArrayND[np.intp | np.float64],
6772
max_events: onp.ArrayND[np.intp | np.float64],
@@ -71,30 +76,113 @@ def handle_events(
7176
def find_active_events(g: onp.ToFloat1D, g_new: onp.ToFloat1D, direction: onp.ArrayND[np.float64]) -> _Int1D: ...
7277

7378
#
74-
@overload
79+
@overload # float, vectorized=False (default), args=None (default)
7580
def solve_ivp(
76-
fun: Callable[Concatenate[float, onp.Array1D[_SCT_cf], ...], onp.ArrayND[_SCT_cf]],
77-
t_span: Sequence[onp.ToFloat],
78-
y0: onp.ToArray1D,
79-
method: _IVPMethod | type[OdeSolver] = "RK45",
81+
fun: Callable[[np.float64, _Float1D], onp.ToFloat1D | float],
82+
t_span: Sequence[float],
83+
y0: onp.ToFloat1D,
84+
method: _IVPMethod = "RK45",
8085
t_eval: onp.ToFloat1D | None = None,
8186
dense_output: bool = False,
82-
events: _Events[_SCT_cf] | None = None,
87+
events: _Events[np.float64] | None = None,
8388
vectorized: onp.ToFalse = False,
84-
args: tuple[object, ...] | None = None,
89+
args: None = None,
8590
**options: Unpack[_SolverOptions],
86-
) -> OdeResult[_SCT_cf]: ...
87-
@overload
91+
) -> OdeResult[np.float64]: ...
92+
@overload # float, vectorized=False (default), args=<given>
8893
def solve_ivp(
89-
fun: Callable[Concatenate[_Float1D, onp.Array2D[_SCT_cf], ...], onp.ArrayND[_SCT_cf]],
90-
t_span: Sequence[onp.ToFloat],
91-
y0: onp.ToArray1D,
92-
method: _IVPMethod | type[OdeSolver] = "RK45",
94+
fun: Callable[[np.float64, _Float1D, *_Ts], onp.ToFloat1D | float],
95+
t_span: Sequence[float],
96+
y0: onp.ToFloat1D,
97+
method: _IVPMethod = "RK45",
9398
t_eval: onp.ToFloat1D | None = None,
9499
dense_output: bool = False,
95-
events: _Events[_SCT_cf] | None = None,
100+
events: _Events[np.float64] | None = None,
101+
vectorized: onp.ToFalse = False,
102+
*,
103+
args: tuple[*_Ts],
104+
**options: Unpack[_SolverOptions],
105+
) -> OdeResult[np.float64]: ...
106+
@overload # float, vectorized=True, args=None (default)
107+
def solve_ivp(
108+
fun: Callable[[_Float1D, _Float2D], onp.ToFloat2D],
109+
t_span: Sequence[float],
110+
y0: onp.ToFloat1D,
111+
method: _IVPMethod = "RK45",
112+
t_eval: onp.ToFloat1D | None = None,
113+
dense_output: bool = False,
114+
events: _Events[np.float64] | None = None,
115+
*,
116+
vectorized: onp.ToTrue,
117+
args: None = None,
118+
**options: Unpack[_SolverOptions],
119+
) -> OdeResult[np.float64]: ...
120+
@overload # float, vectorized=True, args=<given>
121+
def solve_ivp(
122+
fun: Callable[[_Float1D, _Float2D, *_Ts], onp.ToFloat2D],
123+
t_span: Sequence[float],
124+
y0: onp.ToFloat1D,
125+
method: _IVPMethod = "RK45",
126+
t_eval: onp.ToFloat1D | None = None,
127+
dense_output: bool = False,
128+
events: _Events[np.float64] | None = None,
129+
*,
130+
vectorized: onp.ToTrue,
131+
args: tuple[*_Ts],
132+
**options: Unpack[_SolverOptions],
133+
) -> OdeResult[np.float64]: ...
134+
@overload # complex, vectorized=False (default), args=None (default)
135+
def solve_ivp(
136+
fun: Callable[[np.float64, _Complex1D], onp.ToComplex1D | complex],
137+
t_span: Sequence[float],
138+
y0: onp.ToComplex1D,
139+
method: _IVPMethod = "RK45",
140+
t_eval: onp.ToFloat1D | None = None,
141+
dense_output: bool = False,
142+
events: _Events[np.complex128] | None = None,
143+
vectorized: onp.ToFalse = False,
144+
args: None = None,
145+
**options: Unpack[_SolverOptions],
146+
) -> OdeResult[np.complex128]: ...
147+
@overload # complex, vectorized=False (default), args=<given>
148+
def solve_ivp(
149+
fun: Callable[[np.float64, _Complex1D, *_Ts], onp.ToComplex1D | complex],
150+
t_span: Sequence[float],
151+
y0: onp.ToComplex1D,
152+
method: _IVPMethod = "RK45",
153+
t_eval: onp.ToFloat1D | None = None,
154+
dense_output: bool = False,
155+
events: _Events[np.complex128] | None = None,
156+
vectorized: onp.ToFalse = False,
157+
*,
158+
args: tuple[*_Ts],
159+
**options: Unpack[_SolverOptions],
160+
) -> OdeResult[np.complex128]: ...
161+
@overload # complex, vectorized=True, args=None (default)
162+
def solve_ivp(
163+
fun: Callable[[_Float1D, _Complex2D], onp.ToComplex2D],
164+
t_span: Sequence[float],
165+
y0: onp.ToComplex1D,
166+
method: _IVPMethod = "RK45",
167+
t_eval: onp.ToFloat1D | None = None,
168+
dense_output: bool = False,
169+
events: _Events[np.complex128] | None = None,
170+
*,
171+
vectorized: onp.ToTrue,
172+
args: None = None,
173+
**options: Unpack[_SolverOptions],
174+
) -> OdeResult[np.complex128]: ...
175+
@overload # complex, vectorized=True, args=<given>
176+
def solve_ivp(
177+
fun: Callable[[_Float1D, _Complex2D, *_Ts], onp.ToComplex2D],
178+
t_span: Sequence[float],
179+
y0: onp.ToComplex1D,
180+
method: _IVPMethod = "RK45",
181+
t_eval: onp.ToFloat1D | None = None,
182+
dense_output: bool = False,
183+
events: _Events[np.complex128] | None = None,
96184
*,
97185
vectorized: onp.ToTrue,
98-
args: tuple[object, ...] | None = None,
186+
args: tuple[*_Ts],
99187
**options: Unpack[_SolverOptions],
100-
) -> OdeResult[_SCT_cf]: ...
188+
) -> OdeResult[np.complex128]: ...

0 commit comments

Comments
 (0)