Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,15 @@ See the `scipy` columns below for which classes are subscriptable at runtime.

### `scipy.integrate`

| generic type | `scipy-stubs` | `scipy` | |
| ------------------------ | ------------- | -------- | --------------------------------------------------------------------------------------------- |
| `BDF[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.BDF.html) |
| `DOP853[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.DOP853.html) |
| `RK23[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK23.html) |
| `RK45[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK45.html) |
| `ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) |
| `complex_ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.complex_ode.html) |
| generic type | `scipy-stubs` | `scipy` | |
| ----------------------------- | ----------------------- | -------- | --------------------------------------------------------------------------------------------- |
| `BDF[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.BDF.html) |
| `DOP853[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.DOP853.html) |
| `RK23[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK23.html) |
| `RK45[T: f64 \| c128]` | `>=1.14.0.1` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.RK45.html) |
| `ode[*ArgTs]` | `>=1.14.0.0, <1.16.0.3` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) |
| `ode[T: f64 \| c128, *ArgTs]` | `>=1.16.0.3` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.ode.html) |
| `complex_ode[*ArgTs]` | `>=1.14.0.0` | `>=1.17` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.complex_ode.html) |

### `scipy.interpolate`

Expand Down
137 changes: 84 additions & 53 deletions scipy-stubs/integrate/_ode.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Any, ClassVar, Final, Generic, Literal, Protocol, Self, TypeAlias, TypedDict, type_check_only
from typing import Any, ClassVar, Final, Generic, Literal, Self, TypeAlias, TypedDict, overload, type_check_only
from typing_extensions import TypeVar, TypeVarTuple, Unpack, override

import numpy as np
Expand All @@ -8,8 +8,11 @@ import optype.numpy.compat as npc

__all__ = ["complex_ode", "ode"]

_SCT_co = TypeVar("_SCT_co", covariant=True, bound=npc.inexact, default=np.float64 | np.complex128)
_Ts = TypeVarTuple("_Ts", default=Unpack[tuple[()]])
_Ts = TypeVarTuple("_Ts", default=Unpack[tuple[Any, ...]])
_Inexact64T_co = TypeVar("_Inexact64T_co", bound=npc.inexact, default=np.float64 | np.complex128, covariant=True)

_IntegratorReal: TypeAlias = Literal["vode", "dopri5", "dop853", "lsoda"]
_IntegratorComplex: TypeAlias = Literal["vode", "zvode"]

@type_check_only
class _IntegratorParams(TypedDict, total=False):
Expand All @@ -34,56 +37,69 @@ class _IntegratorParams(TypedDict, total=False):
beta: float
verbosity: int

@type_check_only
class _ODEFuncF(Protocol[*_Ts]):
def __call__(self, t: float, y: float | onp.ArrayND[np.float64], /, *args: *_Ts) -> float | onp.ArrayND[npc.floating]: ...
###

@type_check_only
class _ODEFuncC(Protocol[*_Ts]):
def __call__(
self, t: float, y: complex | onp.ArrayND[np.complex128], /, *args: *_Ts
) -> complex | onp.ArrayND[npc.complexfloating]: ...
class IntegratorConcurrencyError(RuntimeError):
def __init__(self, /, name: str) -> None: ...

_SolOutFunc: TypeAlias = Callable[[float, onp.Array1D[npc.inexact]], Literal[0, -1]]
class ode(Generic[_Inexact64T_co, *_Ts]):
f: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex1D]
jac: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex2D] | None
f_params: tuple[*_Ts]
jac_params: tuple[*_Ts]
stiff: Literal[0, 1]
t: float

###
def __init__(
self,
/,
f: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex1D],
jac: Callable[[float, onp.Array1D[_Inexact64T_co], *_Ts], complex | onp.ToComplex2D] | None = None,
) -> None: ...

class ode(Generic[*_Ts]):
stiff: int
f: _ODEFuncF[*_Ts]
f_params: tuple[()] | tuple[*_Ts]
jac: _ODEFuncF[*_Ts] | None
jac_params: tuple[()] | tuple[*_Ts]
t: float
def __init__(self, /, f: _ODEFuncF[*_Ts], jac: _ODEFuncF[*_Ts] | None = None) -> None: ...
#
@property
def y(self, /) -> float: ...
def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> float: ...
def set_initial_value(self, /, y: onp.ToComplex | onp.ToComplexND, t: float = 0.0) -> Self: ...
def set_integrator(self, /, name: str, **integrator_params: Unpack[_IntegratorParams]) -> Self: ...
def y(self, /) -> onp.Array1D[_Inexact64T_co]: ...

#
@overload
def set_initial_value(
self: ode[np.float64, *_Ts], /, y: float | onp.ToFloat1D, t: float = 0.0
) -> ode[_Inexact64T_co, *_Ts]: ...
@overload
def set_initial_value(
self: ode[np.complex128, *_Ts], /, y: complex | onp.ToComplex1D, t: float = 0.0
) -> ode[_Inexact64T_co, *_Ts]: ...

#
@overload
def set_integrator(
self: ode[np.float64, *_Ts], /, name: _IntegratorReal, **integrator_params: Unpack[_IntegratorParams]
) -> ode[_Inexact64T_co, *_Ts]: ...
@overload
def set_integrator(
self: ode[np.complex128, *_Ts], /, name: _IntegratorComplex, **integrator_params: Unpack[_IntegratorParams]
) -> ode[_Inexact64T_co, *_Ts]: ...

#
def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> onp.Array1D[_Inexact64T_co]: ...
def successful(self, /) -> bool: ...
def get_return_code(self, /) -> Literal[-7, -6, -5, -4, -3, -2, -1, 1, 2]: ...
def set_f_params(self, /, *args: *_Ts) -> Self: ...
def set_jac_params(self, /, *args: *_Ts) -> Self: ...
def set_solout(self, /, solout: _SolOutFunc) -> None: ...
def get_return_code(self, /) -> Literal[-7, -6, -5, -4, -3, -2, -1, 1, 2]: ...
def successful(self, /) -> bool: ...
def set_solout(self, /, solout: Callable[[float, onp.Array1D[_Inexact64T_co]], Literal[-1, 0] | None]) -> None: ...

class complex_ode(ode[*_Ts], Generic[*_Ts]):
cf: _ODEFuncC[*_Ts]
cjac: _ODEFuncC[*_Ts] | None
class complex_ode(ode[np.complex128, *_Ts], Generic[*_Ts]):
cf: Callable[[float, onp.Array1D[np.complex128], *_Ts], complex | onp.ToComplex1D]
cjac: Callable[[float, onp.Array1D[np.complex128], *_Ts], complex | onp.ToComplex2D] | None
tmp: onp.Array1D[np.float64]
def __init__(self, /, f: _ODEFuncC[*_Ts], jac: _ODEFuncC[*_Ts] | None = None) -> None: ...
@property

@override
def y(self, /) -> complex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def set_integrator(self, /, name: _IntegratorReal, **integrator_params: Unpack[_IntegratorParams]) -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
@override
def integrate(self, /, t: float, step: bool = False, relax: bool = False) -> complex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def set_initial_value(self, /, y: complex | onp.ToComplex1D, t: float = 0.0) -> Self: ...

def find_integrator(name: str) -> type[IntegratorBase] | None: ...

class IntegratorConcurrencyError(RuntimeError):
def __init__(self, /, name: str) -> None: ...

class IntegratorBase(Generic[_SCT_co]):
class IntegratorBase(Generic[_Inexact64T_co]):
runner: ClassVar[Callable[..., tuple[Any, ...]] | None] # fortran function or unavailable
supports_run_relax: ClassVar[Literal[0, 1] | None] = None
supports_step: ClassVar[Literal[0, 1] | None] = None
Expand All @@ -101,38 +117,38 @@ class IntegratorBase(Generic[_SCT_co]):
def run(
self,
/,
f: Callable[..., _SCT_co],
jac: Callable[..., onp.ArrayND[_SCT_co]] | None,
f: Callable[..., _Inexact64T_co],
jac: Callable[..., onp.ArrayND[_Inexact64T_co]] | None,
y0: complex,
t0: float,
t1: float,
f_params: tuple[object, ...],
jac_params: tuple[object, ...],
) -> tuple[_SCT_co, float]: ...
) -> tuple[_Inexact64T_co, float]: ...
def step(
self,
/,
f: Callable[..., _SCT_co],
jac: Callable[..., onp.ArrayND[_SCT_co]],
f: Callable[..., _Inexact64T_co],
jac: Callable[..., onp.ArrayND[_Inexact64T_co]],
y0: complex,
t0: float,
t1: float,
f_params: tuple[object, ...],
jac_params: tuple[object, ...],
) -> tuple[_SCT_co, float]: ...
) -> tuple[_Inexact64T_co, float]: ...
def run_relax(
self,
/,
f: Callable[..., _SCT_co],
jac: Callable[..., onp.ArrayND[_SCT_co]],
f: Callable[..., _Inexact64T_co],
jac: Callable[..., onp.ArrayND[_Inexact64T_co]],
y0: complex,
t0: float,
t1: float,
f_params: tuple[object, ...],
jac_params: tuple[object, ...],
) -> tuple[_SCT_co, float]: ...
) -> tuple[_Inexact64T_co, float]: ...

class vode(IntegratorBase[_SCT_co], Generic[_SCT_co]):
class vode(IntegratorBase[_Inexact64T_co], Generic[_Inexact64T_co]):
messages: ClassVar[dict[int, str]] = ...

active_global_handle: int
Expand Down Expand Up @@ -210,14 +226,16 @@ class dopri5(IntegratorBase[np.float64]):
method: None = None, # unused
verbosity: int = -1,
) -> None: ...
def set_solout(self, /, solout: _SolOutFunc | None, complex: bool = False) -> None: ...
def set_solout(
self, /, solout: Callable[[float, onp.Array1D[np.float64]], Literal[0, -1]] | None, complex: bool = False
) -> None: ...
def _solout(
self,
/,
nr: int, # unused
xold: object, # unused
x: float,
y: onp.Array1D[npc.floating],
y: onp.Array1D[np.float64],
nd: int, # unused
icomp: int, # unused
con: object, # unused
Expand Down Expand Up @@ -280,3 +298,16 @@ class lsoda(IntegratorBase[np.float64]):
max_order_s: int = 5,
method: None = None, # ignored
) -> None: ...

@overload
def find_integrator(name: Literal["vode"]) -> type[vode]: ...
@overload
def find_integrator(name: Literal["zvode"]) -> type[zvode]: ...
@overload
def find_integrator(name: Literal["dopri5"]) -> type[dopri5]: ...
@overload
def find_integrator(name: Literal["dop853"]) -> type[dop853]: ...
@overload
def find_integrator(name: Literal["lsoda"]) -> type[lsoda]: ...
@overload
def find_integrator(name: str) -> type[IntegratorBase] | None: ...