diff --git a/README.md b/README.md index d10b3031..8f48937a 100644 --- a/README.md +++ b/README.md @@ -170,14 +170,15 @@ See the `scipy` columns below for which classes are subscriptable at runtime. ### `scipy.integrate` -| generic type | `scipy-stubs` | `scipy` | | -| ------------------------ | ------------- | -------- | --------------------------------------------------------------------------------------------- | -| `BDF[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.BDF.html) | -| `DOP853[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.DOP853.html) | -| `RK23[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK23.html) | -| `RK45[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK45.html) | -| `ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) | -| `complex_ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.complex_ode.html) | +| generic type | `scipy-stubs` | `scipy` | | +| ----------------------------- | ----------------------- | -------- | --------------------------------------------------------------------------------------------- | +| `BDF[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.BDF.html) | +| `DOP853[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.DOP853.html) | +| `RK23[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK23.html) | +| `RK45[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK45.html) | +| `ode[*ArgTs]` | `>=1.14.0.0, <1.16.0.3` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) | +| `ode[T: f64 \| c128, *ArgTs]` | `>=1.16.0.3` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) | +| `complex_ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.complex_ode.html) | ### `scipy.interpolate` diff --git a/scipy-stubs/integrate/_ode.pyi b/scipy-stubs/integrate/_ode.pyi index b116ac99..5aa75e21 100644 --- a/scipy-stubs/integrate/_ode.pyi +++ b/scipy-stubs/integrate/_ode.pyi @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any, ClassVar, Final, Generic, Literal, Protocol, Self, TypeAlias, TypedDict, type_check_only +from typing import Any, ClassVar, Final, Generic, Literal, Self, TypeAlias, TypedDict, overload, type_check_only from typing_extensions import TypeVar, TypeVarTuple, Unpack, override import numpy as np @@ -8,8 +8,11 @@ import optype.numpy.compat as npc __all__ = ["complex_ode", "ode"] -_SCT_co = TypeVar("_SCT_co", covariant=True, bound=npc.inexact, default=np.float64 | np.complex128) -_Ts = TypeVarTuple("_Ts", default=Unpack[tuple[()]]) +_Ts = TypeVarTuple("_Ts", default=Unpack[tuple[Any, ...]]) +_Inexact64T_co = TypeVar("_Inexact64T_co", bound=npc.inexact, default=np.float64 | np.complex128, covariant=True) + +_IntegratorReal: TypeAlias = Literal["vode", "dopri5", "dop853", "lsoda"] +_IntegratorComplex: TypeAlias = Literal["vode", "zvode"] @type_check_only class _IntegratorParams(TypedDict, total=False): @@ -34,56 +37,69 @@ class _IntegratorParams(TypedDict, total=False): beta: float verbosity: int -@type_check_only -class _ODEFuncF(Protocol[*_Ts]): - def __call__(self, t: float, y: float | onp.ArrayND[np.float64], /, *args: *_Ts) -> float | onp.ArrayND[npc.floating]: ... +### -@type_check_only -class _ODEFuncC(Protocol[*_Ts]): - def __call__( - self, t: float, y: complex | onp.ArrayND[np.complex128], /, *args: *_Ts - ) -> complex | onp.ArrayND[npc.complexfloating]: ... +class IntegratorConcurrencyError(RuntimeError): + def __init__(self, /, name: str) -> None: ... -_SolOutFunc: TypeAlias = Callable[[float, onp.Array1D[npc.inexact]], Literal[0, -1]] +class ode(Generic[_Inexact64T_co, *_Ts]): + f: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex1D] + jac: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex2D] | None + f_params: tuple[*_Ts] + jac_params: tuple[*_Ts] + stiff: Literal[0, 1] + t: float -### + def __init__( + self, + /, + f: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex1D], + jac: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex2D] | None = None, + ) -> None: ... -class ode(Generic[*_Ts]): - stiff: int - f: _ODEFuncF[*_Ts] - f_params: tuple[()] | tuple[*_Ts] - jac: _ODEFuncF[*_Ts] | None - jac_params: tuple[()] | tuple[*_Ts] - t: float - def __init__(self, /, f: _ODEFuncF[*_Ts], jac: _ODEFuncF[*_Ts] | None = None) -> None: ... + # @property - def y(self, /) -> float: ... - def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> float: ... - def set_initial_value(self, /, y: onp.ToComplex | onp.ToComplexND, t: float = 0.0) -> Self: ... - def set_integrator(self, /, name: str, **integrator_params: Unpack[_IntegratorParams]) -> Self: ... + def y(self, /) -> onp.Array1D[_Inexact64T_co]: ... + + # + @overload + def set_initial_value( + self: ode[np.float64, *_Ts], /, y: float | onp.ToFloat1D, t: float = 0.0 + ) -> ode[_Inexact64T_co, *_Ts]: ... + @overload + def set_initial_value( + self: ode[np.complex128, *_Ts], /, y: complex | onp.ToComplex1D, t: float = 0.0 + ) -> ode[_Inexact64T_co, *_Ts]: ... + + # + @overload + def set_integrator( + self: ode[np.float64, *_Ts], /, name: _IntegratorReal, **integrator_params: Unpack[_IntegratorParams] + ) -> ode[_Inexact64T_co, *_Ts]: ... + @overload + def set_integrator( + self: ode[np.complex128, *_Ts], /, name: _IntegratorComplex, **integrator_params: Unpack[_IntegratorParams] + ) -> ode[_Inexact64T_co, *_Ts]: ... + + # + def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> onp.Array1D[_Inexact64T_co]: ... + def successful(self, /) -> bool: ... + def get_return_code(self, /) -> Literal[-7, -6, -5, -4, -3, -2, -1, 1, 2]: ... def set_f_params(self, /, *args: *_Ts) -> Self: ... def set_jac_params(self, /, *args: *_Ts) -> Self: ... - def set_solout(self, /, solout: _SolOutFunc) -> None: ... - def get_return_code(self, /) -> Literal[-7, -6, -5, -4, -3, -2, -1, 1, 2]: ... - def successful(self, /) -> bool: ... + def set_solout(self, /, solout: Callable[[float, onp.Array1D[_Inexact64T_co]], Literal[-1, 0] | None]) -> None: ... -class complex_ode(ode[*_Ts], Generic[*_Ts]): - cf: _ODEFuncC[*_Ts] - cjac: _ODEFuncC[*_Ts] | None +class complex_ode(ode[np.complex128, *_Ts], Generic[*_Ts]): + cf: Callable[[float, onp.Array1D[np.complex128], *_Ts], complex | onp.ToComplex1D] + cjac: Callable[[float, onp.Array1D[np.complex128], *_Ts], complex | onp.ToComplex2D] | None tmp: onp.Array1D[np.float64] - def __init__(self, /, f: _ODEFuncC[*_Ts], jac: _ODEFuncC[*_Ts] | None = None) -> None: ... - @property + @override - def y(self, /) -> complex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + def set_integrator(self, /, name: _IntegratorReal, **integrator_params: Unpack[_IntegratorParams]) -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] @override - def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> complex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + def set_initial_value(self, /, y: complex | onp.ToComplex1D, t: float = 0.0) -> Self: ... -def find_integrator(name: str) -> type[IntegratorBase] | None: ... - -class IntegratorConcurrencyError(RuntimeError): - def __init__(self, /, name: str) -> None: ... - -class IntegratorBase(Generic[_SCT_co]): +class IntegratorBase(Generic[_Inexact64T_co]): runner: ClassVar[Callable[..., tuple[Any, ...]] | None] # fortran function or unavailable supports_run_relax: ClassVar[Literal[0, 1] | None] = None supports_step: ClassVar[Literal[0, 1] | None] = None @@ -101,38 +117,38 @@ class IntegratorBase(Generic[_SCT_co]): def run( self, /, - f: Callable[..., _SCT_co], - jac: Callable[..., onp.ArrayND[_SCT_co]] | None, + f: Callable[..., _Inexact64T_co], + jac: Callable[..., onp.ArrayND[_Inexact64T_co]] | None, y0: complex, t0: float, t1: float, f_params: tuple[object, ...], jac_params: tuple[object, ...], - ) -> tuple[_SCT_co, float]: ... + ) -> tuple[_Inexact64T_co, float]: ... def step( self, /, - f: Callable[..., _SCT_co], - jac: Callable[..., onp.ArrayND[_SCT_co]], + f: Callable[..., _Inexact64T_co], + jac: Callable[..., onp.ArrayND[_Inexact64T_co]], y0: complex, t0: float, t1: float, f_params: tuple[object, ...], jac_params: tuple[object, ...], - ) -> tuple[_SCT_co, float]: ... + ) -> tuple[_Inexact64T_co, float]: ... def run_relax( self, /, - f: Callable[..., _SCT_co], - jac: Callable[..., onp.ArrayND[_SCT_co]], + f: Callable[..., _Inexact64T_co], + jac: Callable[..., onp.ArrayND[_Inexact64T_co]], y0: complex, t0: float, t1: float, f_params: tuple[object, ...], jac_params: tuple[object, ...], - ) -> tuple[_SCT_co, float]: ... + ) -> tuple[_Inexact64T_co, float]: ... -class vode(IntegratorBase[_SCT_co], Generic[_SCT_co]): +class vode(IntegratorBase[_Inexact64T_co], Generic[_Inexact64T_co]): messages: ClassVar[dict[int, str]] = ... active_global_handle: int @@ -210,14 +226,16 @@ class dopri5(IntegratorBase[np.float64]): method: None = None, # unused verbosity: int = -1, ) -> None: ... - def set_solout(self, /, solout: _SolOutFunc | None, complex: bool = False) -> None: ... + def set_solout( + self, /, solout: Callable[[float, onp.Array1D[np.float64]], Literal[0, -1]] | None, complex: bool = False + ) -> None: ... def _solout( self, /, nr: int, # unused xold: object, # unused x: float, - y: onp.Array1D[npc.floating], + y: onp.Array1D[np.float64], nd: int, # unused icomp: int, # unused con: object, # unused @@ -280,3 +298,16 @@ class lsoda(IntegratorBase[np.float64]): max_order_s: int = 5, method: None = None, # ignored ) -> None: ... + +@overload +def find_integrator(name: Literal["vode"]) -> type[vode]: ... +@overload +def find_integrator(name: Literal["zvode"]) -> type[zvode]: ... +@overload +def find_integrator(name: Literal["dopri5"]) -> type[dopri5]: ... +@overload +def find_integrator(name: Literal["dop853"]) -> type[dop853]: ... +@overload +def find_integrator(name: Literal["lsoda"]) -> type[lsoda]: ... +@overload +def find_integrator(name: str) -> type[IntegratorBase] | None: ...