Skip to content

Commit 9799dc6

Browse files
authored
🐛 integrate: several OdeSolver fixes (#880)
2 parents 7c47019 + 06bbf89 commit 9799dc6

File tree

7 files changed

+49
-39
lines changed

7 files changed

+49
-39
lines changed
Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,85 @@
11
from collections.abc import Callable
2-
from typing import Any, ClassVar, Final, Generic, Literal, TypeVar, overload
2+
from typing import Any, ClassVar, Final, Generic, Literal, TypeAlias, overload
3+
from typing_extensions import TypeVar
34

45
import numpy as np
56
import optype.numpy as onp
67
import optype.numpy.compat as npc
78

8-
_VT = TypeVar("_VT", bound=onp.ArrayND[npc.inexact], default=onp.ArrayND[Any])
9+
_ScalarT = TypeVar("_ScalarT", bound=np.float64 | np.complex128, default=np.float64)
10+
_ScalarT_co = TypeVar("_ScalarT_co", bound=npc.inexact, default=np.float64 | Any, covariant=True)
911

10-
class OdeSolver:
12+
_ToFunReal: TypeAlias = Callable[[float, onp.ArrayND[np.float64]], onp.ToFloatND]
13+
_ToFunComplex: TypeAlias = Callable[[float, onp.ArrayND[np.complex128]], onp.ToComplexND]
14+
15+
@overload
16+
def check_arguments(
17+
fun: _ToFunReal, y0: onp.ToFloatND, support_complex: bool
18+
) -> Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]]: ...
19+
@overload
20+
def check_arguments(
21+
fun: _ToFunComplex, y0: onp.ToJustComplexND, support_complex: Literal[True]
22+
) -> Callable[[float, onp.ArrayND[np.complex128]], onp.ArrayND[np.complex128]]: ...
23+
24+
class OdeSolver(Generic[_ScalarT]):
1125
TOO_SMALL_STEP: ClassVar[str] = ...
1226

1327
t: float
1428
t_old: float
1529
t_bound: float
30+
y: onp.ArrayND[_ScalarT]
1631
vectorized: bool
17-
fun: Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]]
18-
fun_single: Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]]
19-
fun_vectorized: Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]]
20-
direction: float
21-
n: int
32+
fun: Callable[[float, onp.ArrayND[_ScalarT]], onp.ArrayND[_ScalarT]]
33+
fun_single: Callable[[float, onp.Array1D[_ScalarT]], onp.Array1D[_ScalarT]]
34+
fun_vectorized: Callable[[float, onp.Array2D[_ScalarT]], onp.Array2D[_ScalarT]]
35+
direction: np.float64
2236
status: Literal["running", "finished", "failed"]
37+
n: int
2338
nfev: int
2439
njev: int
2540
nlu: int
2641

2742
@overload
2843
def __init__(
29-
self,
44+
self: OdeSolver[np.float64],
3045
/,
31-
fun: Callable[[float, onp.ArrayND[np.float64]], onp.ToFloatND],
32-
t0: onp.ToFloatND,
46+
fun: _ToFunReal,
47+
t0: float,
3348
y0: onp.ToFloatND,
34-
t_bound: onp.ToFloat,
49+
t_bound: float,
3550
vectorized: bool,
3651
support_complex: onp.ToBool = False,
3752
) -> None: ...
3853
@overload
3954
def __init__(
40-
self,
55+
self: OdeSolver[np.complex128],
4156
/,
42-
fun: Callable[[float, onp.ArrayND[np.float64 | np.complex128]], onp.ToComplexND],
43-
t0: onp.ToFloat,
44-
y0: onp.ToComplexND,
57+
fun: _ToFunComplex,
58+
t0: float,
59+
y0: onp.ToJustComplexND,
4560
t_bound: onp.ToFloat,
4661
vectorized: bool,
4762
support_complex: onp.ToTrue,
4863
) -> None: ...
4964
@property
50-
def step_size(self, /) -> float | None: ...
65+
def step_size(self, /) -> np.float64 | None: ...
5166
def step(self, /) -> str | None: ...
52-
def dense_output(self, /) -> ConstantDenseOutput: ...
67+
def dense_output(self, /) -> ConstantDenseOutput[_ScalarT]: ...
5368

54-
class DenseOutput:
69+
class DenseOutput(Generic[_ScalarT_co]):
5570
t_old: Final[float]
5671
t: Final[float]
5772
t_min: Final[float]
5873
t_max: Final[float]
5974

60-
def __init__(self, /, t_old: onp.ToFloat, t: onp.ToFloat) -> None: ...
75+
def __init__(self, /, t_old: float, t: float) -> None: ...
76+
77+
#
6178
@overload
62-
def __call__(self, /, t: onp.ToFloat) -> onp.Array1D[npc.inexact]: ...
79+
def __call__(self, /, t: onp.ToFloat) -> onp.Array1D[_ScalarT_co]: ...
6380
@overload
64-
def __call__(self, /, t: onp.ToFloatND) -> onp.ArrayND[npc.inexact]: ...
81+
def __call__(self, /, t: onp.ToFloatND) -> onp.ArrayND[_ScalarT_co]: ...
6582

66-
class ConstantDenseOutput(DenseOutput, Generic[_VT]):
67-
value: _VT
68-
def __init__(self, /, t_old: onp.ToFloat, t: onp.ToFloat, value: _VT) -> None: ...
69-
70-
def check_arguments(
71-
fun: Callable[[float, onp.ArrayND[np.float64]], onp.ToComplexND], y0: onp.ToComplexND, support_complex: bool
72-
) -> (
73-
Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]]
74-
| Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.complex128]]
75-
): ...
83+
class ConstantDenseOutput(DenseOutput[_ScalarT_co], Generic[_ScalarT_co]):
84+
value: onp.ArrayND[_ScalarT_co]
85+
def __init__(self, /, t_old: float, t: float, value: onp.ArrayND[_ScalarT_co]) -> None: ...

scipy-stubs/integrate/_ivp/bdf.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class BDF(OdeSolver, Generic[_SCT_co]):
6363
**extraneous: Never,
6464
) -> None: ...
6565

66-
class BdfDenseOutput(DenseOutput):
66+
class BdfDenseOutput(DenseOutput[np.float64]):
6767
order: int
6868
t_shift: onp.ArrayND[np.float64]
6969
denom: onp.ArrayND[np.float64]

scipy-stubs/integrate/_ivp/common.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ from scipy.sparse import csc_matrix
1212

1313
_FloatingT = TypeVar("_FloatingT", bound=npc.floating)
1414
_ToFloatT = TypeVar("_ToFloatT", bound=onp.ToFloat)
15-
_InterpT_co = TypeVar("_InterpT_co", bound=DenseOutput, default=DenseOutput, covariant=True)
15+
_InterpT_co = TypeVar("_InterpT_co", bound=DenseOutput[npc.inexact], default=DenseOutput[Any], covariant=True)
1616

1717
_ToFloat64: TypeAlias = np.float16 | np.float32 | np.float64 | npc.integer | np.bool_
1818

scipy-stubs/integrate/_ivp/ivp.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class OdeResult(_RichResult[Any], Generic[_Inexact64T_co]):
7272
def prepare_events(events: _Events[_Inexact64T]) -> tuple[_Events[_Inexact64T], _Float1D, _Float1D]: ...
7373
def solve_event_equation(event: _FuncEvent[_Inexact64T], sol: _FuncSol[_Inexact64T], t_old: float, t: float) -> float: ...
7474
def handle_events(
75-
sol: DenseOutput,
75+
sol: DenseOutput[Any],
7676
events: Sequence[_FuncEvent[_Inexact64T]],
7777
active_events: onp.ArrayND[np.intp],
7878
event_count: onp.ArrayND[np.intp | np.float64],

scipy-stubs/integrate/_ivp/lsoda.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class LSODA(OdeSolver):
2626
**extraneous: Never,
2727
) -> None: ...
2828

29-
class LsodaDenseOutput(DenseOutput):
29+
class LsodaDenseOutput(DenseOutput[np.float64]):
3030
h: float
3131
yh: onp.Array1D[np.float64]
3232
p: onp.Array1D[np.intp]

scipy-stubs/integrate/_ivp/radau.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class Radau(OdeSolver):
6969
**extraneous: Never,
7070
) -> None: ...
7171

72-
class RadauDenseOutput(DenseOutput):
72+
class RadauDenseOutput(DenseOutput[np.float64]):
7373
order: int
7474
h: float
7575
Q: onp.ArrayND[np.float64]

scipy-stubs/integrate/_ivp/rk.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,15 @@ class DOP853(RungeKutta[_SCT_fc], Generic[_SCT_fc]):
6060

6161
K_extended: onp.ArrayND[_SCT_fc]
6262

63-
class RkDenseOutput(DenseOutput, Generic[_SCT_fc]):
63+
class RkDenseOutput(DenseOutput[_SCT_fc], Generic[_SCT_fc]):
6464
h: float
6565
order: int
6666
Q: onp.ArrayND[_SCT_fc]
6767
y_old: onp.ArrayND[_SCT_fc]
6868

6969
def __init__(self, /, t_old: float, t: float, y_old: onp.ArrayND[_SCT_fc], Q: onp.ArrayND[_SCT_fc]) -> None: ...
7070

71-
class Dop853DenseOutput(DenseOutput, Generic[_SCT_fc]):
71+
class Dop853DenseOutput(DenseOutput[_SCT_fc], Generic[_SCT_fc]):
7272
h: float
7373
F: onp.ArrayND[_SCT_fc]
7474
y_old: onp.ArrayND[_SCT_fc]

0 commit comments

Comments
 (0)