Skip to content

Commit 0126122

Browse files
authored
integrate: improved ode annotations (#766)
2 parents 8c58036 + b578c65 commit 0126122

File tree

2 files changed

+93
-61
lines changed

2 files changed

+93
-61
lines changed

README.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,15 @@ See the `scipy` columns below for which classes are subscriptable at runtime.
170170

171171
### `scipy.integrate`
172172

173-
| generic type | `scipy-stubs` | `scipy` | |
174-
| ------------------------ | ------------- | -------- | --------------------------------------------------------------------------------------------- |
175-
| `BDF[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.BDF.html) |
176-
| `DOP853[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.DOP853.html) |
177-
| `RK23[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK23.html) |
178-
| `RK45[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK45.html) |
179-
| `ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) |
180-
| `complex_ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.complex_ode.html) |
173+
| generic type | `scipy-stubs` | `scipy` | |
174+
| ----------------------------- | ----------------------- | -------- | --------------------------------------------------------------------------------------------- |
175+
| `BDF[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.BDF.html) |
176+
| `DOP853[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.DOP853.html) |
177+
| `RK23[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK23.html) |
178+
| `RK45[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK45.html) |
179+
| `ode[*ArgTs]` | `>=1.14.0.0, <1.16.0.3` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) |
180+
| `ode[T: f64 \| c128, *ArgTs]` | `>=1.16.0.3` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) |
181+
| `complex_ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.complex_ode.html) |
181182

182183
### `scipy.interpolate`
183184

scipy-stubs/integrate/_ode.pyi

Lines changed: 84 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable
2-
from typing import Any, ClassVar, Final, Generic, Literal, Protocol, Self, TypeAlias, TypedDict, type_check_only
2+
from typing import Any, ClassVar, Final, Generic, Literal, Self, TypeAlias, TypedDict, overload, type_check_only
33
from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
44

55
import numpy as np
@@ -8,8 +8,11 @@ import optype.numpy.compat as npc
88

99
__all__ = ["complex_ode", "ode"]
1010

11-
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=npc.inexact, default=np.float64 | np.complex128)
12-
_Ts = TypeVarTuple("_Ts", default=Unpack[tuple[()]])
11+
_Ts = TypeVarTuple("_Ts", default=Unpack[tuple[Any, ...]])
12+
_Inexact64T_co = TypeVar("_Inexact64T_co", bound=npc.inexact, default=np.float64 | np.complex128, covariant=True)
13+
14+
_IntegratorReal: TypeAlias = Literal["vode", "dopri5", "dop853", "lsoda"]
15+
_IntegratorComplex: TypeAlias = Literal["vode", "zvode"]
1316

1417
@type_check_only
1518
class _IntegratorParams(TypedDict, total=False):
@@ -34,56 +37,69 @@ class _IntegratorParams(TypedDict, total=False):
3437
beta: float
3538
verbosity: int
3639

37-
@type_check_only
38-
class _ODEFuncF(Protocol[*_Ts]):
39-
def __call__(self, t: float, y: float | onp.ArrayND[np.float64], /, *args: *_Ts) -> float | onp.ArrayND[npc.floating]: ...
40+
###
4041

41-
@type_check_only
42-
class _ODEFuncC(Protocol[*_Ts]):
43-
def __call__(
44-
self, t: float, y: complex | onp.ArrayND[np.complex128], /, *args: *_Ts
45-
) -> complex | onp.ArrayND[npc.complexfloating]: ...
42+
class IntegratorConcurrencyError(RuntimeError):
43+
def __init__(self, /, name: str) -> None: ...
4644

47-
_SolOutFunc: TypeAlias = Callable[[float, onp.Array1D[npc.inexact]], Literal[0, -1]]
45+
class ode(Generic[_Inexact64T_co, *_Ts]):
46+
f: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex1D]
47+
jac: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex2D] | None
48+
f_params: tuple[*_Ts]
49+
jac_params: tuple[*_Ts]
50+
stiff: Literal[0, 1]
51+
t: float
4852

49-
###
53+
def __init__(
54+
self,
55+
/,
56+
f: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex1D],
57+
jac: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex2D] | None = None,
58+
) -> None: ...
5059

51-
class ode(Generic[*_Ts]):
52-
stiff: int
53-
f: _ODEFuncF[*_Ts]
54-
f_params: tuple[()] | tuple[*_Ts]
55-
jac: _ODEFuncF[*_Ts] | None
56-
jac_params: tuple[()] | tuple[*_Ts]
57-
t: float
58-
def __init__(self, /, f: _ODEFuncF[*_Ts], jac: _ODEFuncF[*_Ts] | None = None) -> None: ...
60+
#
5961
@property
60-
def y(self, /) -> float: ...
61-
def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> float: ...
62-
def set_initial_value(self, /, y: onp.ToComplex | onp.ToComplexND, t: float = 0.0) -> Self: ...
63-
def set_integrator(self, /, name: str, **integrator_params: Unpack[_IntegratorParams]) -> Self: ...
62+
def y(self, /) -> onp.Array1D[_Inexact64T_co]: ...
63+
64+
#
65+
@overload
66+
def set_initial_value(
67+
self: ode[np.float64, *_Ts], /, y: float | onp.ToFloat1D, t: float = 0.0
68+
) -> ode[_Inexact64T_co, *_Ts]: ...
69+
@overload
70+
def set_initial_value(
71+
self: ode[np.complex128, *_Ts], /, y: complex | onp.ToComplex1D, t: float = 0.0
72+
) -> ode[_Inexact64T_co, *_Ts]: ...
73+
74+
#
75+
@overload
76+
def set_integrator(
77+
self: ode[np.float64, *_Ts], /, name: _IntegratorReal, **integrator_params: Unpack[_IntegratorParams]
78+
) -> ode[_Inexact64T_co, *_Ts]: ...
79+
@overload
80+
def set_integrator(
81+
self: ode[np.complex128, *_Ts], /, name: _IntegratorComplex, **integrator_params: Unpack[_IntegratorParams]
82+
) -> ode[_Inexact64T_co, *_Ts]: ...
83+
84+
#
85+
def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> onp.Array1D[_Inexact64T_co]: ...
86+
def successful(self, /) -> bool: ...
87+
def get_return_code(self, /) -> Literal[-7, -6, -5, -4, -3, -2, -1, 1, 2]: ...
6488
def set_f_params(self, /, *args: *_Ts) -> Self: ...
6589
def set_jac_params(self, /, *args: *_Ts) -> Self: ...
66-
def set_solout(self, /, solout: _SolOutFunc) -> None: ...
67-
def get_return_code(self, /) -> Literal[-7, -6, -5, -4, -3, -2, -1, 1, 2]: ...
68-
def successful(self, /) -> bool: ...
90+
def set_solout(self, /, solout: Callable[[float, onp.Array1D[_Inexact64T_co]], Literal[-1, 0] | None]) -> None: ...
6991

70-
class complex_ode(ode[*_Ts], Generic[*_Ts]):
71-
cf: _ODEFuncC[*_Ts]
72-
cjac: _ODEFuncC[*_Ts] | None
92+
class complex_ode(ode[np.complex128, *_Ts], Generic[*_Ts]):
93+
cf: Callable[[float, onp.Array1D[np.complex128], *_Ts], complex | onp.ToComplex1D]
94+
cjac: Callable[[float, onp.Array1D[np.complex128], *_Ts], complex | onp.ToComplex2D] | None
7395
tmp: onp.Array1D[np.float64]
74-
def __init__(self, /, f: _ODEFuncC[*_Ts], jac: _ODEFuncC[*_Ts] | None = None) -> None: ...
75-
@property
96+
7697
@override
77-
def y(self, /) -> complex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
98+
def set_integrator(self, /, name: _IntegratorReal, **integrator_params: Unpack[_IntegratorParams]) -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
7899
@override
79-
def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> complex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
100+
def set_initial_value(self, /, y: complex | onp.ToComplex1D, t: float = 0.0) -> Self: ...
80101

81-
def find_integrator(name: str) -> type[IntegratorBase] | None: ...
82-
83-
class IntegratorConcurrencyError(RuntimeError):
84-
def __init__(self, /, name: str) -> None: ...
85-
86-
class IntegratorBase(Generic[_SCT_co]):
102+
class IntegratorBase(Generic[_Inexact64T_co]):
87103
runner: ClassVar[Callable[..., tuple[Any, ...]] | None] # fortran function or unavailable
88104
supports_run_relax: ClassVar[Literal[0, 1] | None] = None
89105
supports_step: ClassVar[Literal[0, 1] | None] = None
@@ -101,38 +117,38 @@ class IntegratorBase(Generic[_SCT_co]):
101117
def run(
102118
self,
103119
/,
104-
f: Callable[..., _SCT_co],
105-
jac: Callable[..., onp.ArrayND[_SCT_co]] | None,
120+
f: Callable[..., _Inexact64T_co],
121+
jac: Callable[..., onp.ArrayND[_Inexact64T_co]] | None,
106122
y0: complex,
107123
t0: float,
108124
t1: float,
109125
f_params: tuple[object, ...],
110126
jac_params: tuple[object, ...],
111-
) -> tuple[_SCT_co, float]: ...
127+
) -> tuple[_Inexact64T_co, float]: ...
112128
def step(
113129
self,
114130
/,
115-
f: Callable[..., _SCT_co],
116-
jac: Callable[..., onp.ArrayND[_SCT_co]],
131+
f: Callable[..., _Inexact64T_co],
132+
jac: Callable[..., onp.ArrayND[_Inexact64T_co]],
117133
y0: complex,
118134
t0: float,
119135
t1: float,
120136
f_params: tuple[object, ...],
121137
jac_params: tuple[object, ...],
122-
) -> tuple[_SCT_co, float]: ...
138+
) -> tuple[_Inexact64T_co, float]: ...
123139
def run_relax(
124140
self,
125141
/,
126-
f: Callable[..., _SCT_co],
127-
jac: Callable[..., onp.ArrayND[_SCT_co]],
142+
f: Callable[..., _Inexact64T_co],
143+
jac: Callable[..., onp.ArrayND[_Inexact64T_co]],
128144
y0: complex,
129145
t0: float,
130146
t1: float,
131147
f_params: tuple[object, ...],
132148
jac_params: tuple[object, ...],
133-
) -> tuple[_SCT_co, float]: ...
149+
) -> tuple[_Inexact64T_co, float]: ...
134150

135-
class vode(IntegratorBase[_SCT_co], Generic[_SCT_co]):
151+
class vode(IntegratorBase[_Inexact64T_co], Generic[_Inexact64T_co]):
136152
messages: ClassVar[dict[int, str]] = ...
137153

138154
active_global_handle: int
@@ -210,14 +226,16 @@ class dopri5(IntegratorBase[np.float64]):
210226
method: None = None, # unused
211227
verbosity: int = -1,
212228
) -> None: ...
213-
def set_solout(self, /, solout: _SolOutFunc | None, complex: bool = False) -> None: ...
229+
def set_solout(
230+
self, /, solout: Callable[[float, onp.Array1D[np.float64]], Literal[0, -1]] | None, complex: bool = False
231+
) -> None: ...
214232
def _solout(
215233
self,
216234
/,
217235
nr: int, # unused
218236
xold: object, # unused
219237
x: float,
220-
y: onp.Array1D[npc.floating],
238+
y: onp.Array1D[np.float64],
221239
nd: int, # unused
222240
icomp: int, # unused
223241
con: object, # unused
@@ -280,3 +298,16 @@ class lsoda(IntegratorBase[np.float64]):
280298
max_order_s: int = 5,
281299
method: None = None, # ignored
282300
) -> None: ...
301+
302+
@overload
303+
def find_integrator(name: Literal["vode"]) -> type[vode]: ...
304+
@overload
305+
def find_integrator(name: Literal["zvode"]) -> type[zvode]: ...
306+
@overload
307+
def find_integrator(name: Literal["dopri5"]) -> type[dopri5]: ...
308+
@overload
309+
def find_integrator(name: Literal["dop853"]) -> type[dop853]: ...
310+
@overload
311+
def find_integrator(name: Literal["lsoda"]) -> type[lsoda]: ...
312+
@overload
313+
def find_integrator(name: str) -> type[IntegratorBase] | None: ...

0 commit comments

Comments
 (0)