|
1 | 1 | from collections.abc import Callable |
2 | | -from typing import Any, ClassVar, Final, Generic, Literal, TypeVar, overload |
| 2 | +from typing import Any, ClassVar, Final, Generic, Literal, TypeAlias, overload |
| 3 | +from typing_extensions import TypeVar |
3 | 4 |
|
4 | 5 | import numpy as np |
5 | 6 | import optype.numpy as onp |
6 | 7 | import optype.numpy.compat as npc |
7 | 8 |
|
8 | | -_VT = TypeVar("_VT", bound=onp.ArrayND[npc.inexact], default=onp.ArrayND[Any]) |
| 9 | +_ScalarT = TypeVar("_ScalarT", bound=np.float64 | np.complex128, default=np.float64) |
| 10 | +_ScalarT_co = TypeVar("_ScalarT_co", bound=npc.inexact, default=np.float64 | Any, covariant=True) |
9 | 11 |
|
10 | | -class OdeSolver: |
| 12 | +_ToFunReal: TypeAlias = Callable[[float, onp.ArrayND[np.float64]], onp.ToFloatND] |
| 13 | +_ToFunComplex: TypeAlias = Callable[[float, onp.ArrayND[np.complex128]], onp.ToComplexND] |
| 14 | + |
| 15 | +@overload |
| 16 | +def check_arguments( |
| 17 | + fun: _ToFunReal, y0: onp.ToFloatND, support_complex: bool |
| 18 | +) -> Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]]: ... |
| 19 | +@overload |
| 20 | +def check_arguments( |
| 21 | + fun: _ToFunComplex, y0: onp.ToJustComplexND, support_complex: Literal[True] |
| 22 | +) -> Callable[[float, onp.ArrayND[np.complex128]], onp.ArrayND[np.complex128]]: ... |
| 23 | + |
| 24 | +class OdeSolver(Generic[_ScalarT]): |
11 | 25 | TOO_SMALL_STEP: ClassVar[str] = ... |
12 | 26 |
|
13 | 27 | t: float |
14 | 28 | t_old: float |
15 | 29 | t_bound: float |
| 30 | + y: onp.ArrayND[_ScalarT] |
16 | 31 | vectorized: bool |
17 | | - fun: Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]] |
18 | | - fun_single: Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]] |
19 | | - fun_vectorized: Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]] |
20 | | - direction: float |
21 | | - n: int |
| 32 | + fun: Callable[[float, onp.ArrayND[_ScalarT]], onp.ArrayND[_ScalarT]] |
| 33 | + fun_single: Callable[[float, onp.Array1D[_ScalarT]], onp.Array1D[_ScalarT]] |
| 34 | + fun_vectorized: Callable[[float, onp.Array2D[_ScalarT]], onp.Array2D[_ScalarT]] |
| 35 | + direction: np.float64 |
22 | 36 | status: Literal["running", "finished", "failed"] |
| 37 | + n: int |
23 | 38 | nfev: int |
24 | 39 | njev: int |
25 | 40 | nlu: int |
26 | 41 |
|
27 | 42 | @overload |
28 | 43 | def __init__( |
29 | | - self, |
| 44 | + self: OdeSolver[np.float64], |
30 | 45 | /, |
31 | | - fun: Callable[[float, onp.ArrayND[np.float64]], onp.ToFloatND], |
32 | | - t0: onp.ToFloatND, |
| 46 | + fun: _ToFunReal, |
| 47 | + t0: float, |
33 | 48 | y0: onp.ToFloatND, |
34 | | - t_bound: onp.ToFloat, |
| 49 | + t_bound: float, |
35 | 50 | vectorized: bool, |
36 | 51 | support_complex: onp.ToBool = False, |
37 | 52 | ) -> None: ... |
38 | 53 | @overload |
39 | 54 | def __init__( |
40 | | - self, |
| 55 | + self: OdeSolver[np.complex128], |
41 | 56 | /, |
42 | | - fun: Callable[[float, onp.ArrayND[np.float64 | np.complex128]], onp.ToComplexND], |
43 | | - t0: onp.ToFloat, |
44 | | - y0: onp.ToComplexND, |
| 57 | + fun: _ToFunComplex, |
| 58 | + t0: float, |
| 59 | + y0: onp.ToJustComplexND, |
45 | 60 | t_bound: onp.ToFloat, |
46 | 61 | vectorized: bool, |
47 | 62 | support_complex: onp.ToTrue, |
48 | 63 | ) -> None: ... |
49 | 64 | @property |
50 | | - def step_size(self, /) -> float | None: ... |
| 65 | + def step_size(self, /) -> np.float64 | None: ... |
51 | 66 | def step(self, /) -> str | None: ... |
52 | | - def dense_output(self, /) -> ConstantDenseOutput: ... |
| 67 | + def dense_output(self, /) -> ConstantDenseOutput[_ScalarT]: ... |
53 | 68 |
|
54 | | -class DenseOutput: |
| 69 | +class DenseOutput(Generic[_ScalarT_co]): |
55 | 70 | t_old: Final[float] |
56 | 71 | t: Final[float] |
57 | 72 | t_min: Final[float] |
58 | 73 | t_max: Final[float] |
59 | 74 |
|
60 | | - def __init__(self, /, t_old: onp.ToFloat, t: onp.ToFloat) -> None: ... |
| 75 | + def __init__(self, /, t_old: float, t: float) -> None: ... |
| 76 | + |
| 77 | + # |
61 | 78 | @overload |
62 | | - def __call__(self, /, t: onp.ToFloat) -> onp.Array1D[npc.inexact]: ... |
| 79 | + def __call__(self, /, t: onp.ToFloat) -> onp.Array1D[_ScalarT_co]: ... |
63 | 80 | @overload |
64 | | - def __call__(self, /, t: onp.ToFloatND) -> onp.ArrayND[npc.inexact]: ... |
| 81 | + def __call__(self, /, t: onp.ToFloatND) -> onp.ArrayND[_ScalarT_co]: ... |
65 | 82 |
|
66 | | -class ConstantDenseOutput(DenseOutput, Generic[_VT]): |
67 | | - value: _VT |
68 | | - def __init__(self, /, t_old: onp.ToFloat, t: onp.ToFloat, value: _VT) -> None: ... |
69 | | - |
70 | | -def check_arguments( |
71 | | - fun: Callable[[float, onp.ArrayND[np.float64]], onp.ToComplexND], y0: onp.ToComplexND, support_complex: bool |
72 | | -) -> ( |
73 | | - Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.float64]] |
74 | | - | Callable[[float, onp.ArrayND[np.float64]], onp.ArrayND[np.complex128]] |
75 | | -): ... |
| 83 | +class ConstantDenseOutput(DenseOutput[_ScalarT_co], Generic[_ScalarT_co]): |
| 84 | + value: onp.ArrayND[_ScalarT_co] |
| 85 | + def __init__(self, /, t_old: float, t: float, value: onp.ArrayND[_ScalarT_co]) -> None: ... |
0 commit comments