Skip to content

Commit a12d226

Browse files
committed
integrate: minor helper function improvements in common
1 parent e36d9dd commit a12d226

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed
Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from collections.abc import Callable, Sequence
2-
from typing import Final, Literal, TypeAlias, TypeVar, overload
1+
from collections.abc import Callable
2+
from typing import Any, Final, Generic, Literal, TypeAlias, overload
3+
from typing_extensions import TypeVar
34

45
import numpy as np
56
import optype as op
@@ -9,13 +10,10 @@ import optype.numpy.compat as npc
910
from .base import DenseOutput
1011
from scipy.sparse import csc_matrix
1112

12-
_SCT = TypeVar("_SCT", bound=np.generic)
13+
_FloatingT = TypeVar("_FloatingT", bound=npc.floating)
1314
_ToFloatT = TypeVar("_ToFloatT", bound=onp.ToFloat)
15+
_InterpT_co = TypeVar("_InterpT_co", bound=DenseOutput, default=DenseOutput, covariant=True)
1416

15-
_Side: TypeAlias = Literal["left", "right"]
16-
_Interpolants: TypeAlias = Sequence[DenseOutput]
17-
18-
_To1D: TypeAlias = Sequence[_SCT] | onp.CanArrayND[_SCT]
1917
_ToFloat64: TypeAlias = np.float16 | np.float32 | np.float64 | npc.integer | np.bool_
2018

2119
###
@@ -28,43 +26,41 @@ NUM_JAC_MIN_FACTOR: Final[float] = ...
2826
NUM_JAC_FACTOR_INCREASE: Final[float] = 10
2927
NUM_JAC_FACTOR_DECREASE: Final[float] = 0.1
3028

31-
class OdeSolution:
29+
class OdeSolution(Generic[_InterpT_co]):
30+
interpolants: list[_InterpT_co]
3231
ts: onp.Array1D[np.float64]
3332
ts_sorted: onp.Array1D[np.float64]
3433
t_min: np.float64
3534
t_max: np.float64
3635
ascending: bool
37-
side: _Side
36+
side: Literal["left", "right"]
3837
n_segments: int
39-
interpolants: _Interpolants
4038

41-
def __init__(self, /, ts: onp.ToFloat1D, interpolants: _Interpolants, alt_segment: op.CanBool = False) -> None: ...
39+
def __init__(self, /, ts: onp.ToFloat1D, interpolants: list[_InterpT_co], alt_segment: op.CanBool = False) -> None: ...
40+
41+
#
4242
@overload
4343
def __call__(self, /, t: float | _ToFloat64) -> onp.Array1D[np.float64]: ...
4444
@overload
45-
def __call__(self, /, t: np.complex64 | np.complex128) -> onp.Array1D[np.complex128]: ...
45+
def __call__(self, /, t: op.JustComplex | np.complex128 | np.complex64) -> onp.Array1D[np.complex128]: ...
4646
@overload
4747
def __call__(self, /, t: np.longdouble) -> onp.Array1D[np.longdouble]: ...
4848
@overload
4949
def __call__(self, /, t: np.clongdouble) -> onp.Array1D[np.clongdouble]: ...
5050
@overload
51-
def __call__(self, /, t: complex) -> onp.Array1D[np.float64 | np.complex128]: ...
52-
@overload
53-
def __call__(self, /, t: Sequence[float | _ToFloat64] | onp.CanArrayND[_ToFloat64]) -> onp.Array2D[np.float64]: ...
54-
@overload
55-
def __call__(self, /, t: _To1D[np.complex64 | np.complex128]) -> onp.Array2D[np.complex128]: ...
51+
def __call__(self, /, t: onp.ToArray1D[float, _ToFloat64]) -> onp.Array2D[np.float64]: ...
5652
@overload
57-
def __call__(self, /, t: _To1D[np.clongdouble]) -> onp.Array2D[np.clongdouble]: ...
53+
def __call__(self, /, t: onp.ToArray1D[op.JustComplex, np.complex128 | np.complex64]) -> onp.Array2D[np.complex128]: ...
5854
@overload
59-
def __call__(self, /, t: Sequence[complex]) -> onp.Array2D[np.float64 | np.complex128]: ...
55+
def __call__(self, /, t: onp.CanArrayND[np.clongdouble]) -> onp.Array2D[np.clongdouble]: ...
6056

61-
def validate_first_step(first_step: _ToFloatT, t0: onp.ToFloat, t_bound: onp.ToFloat) -> _ToFloatT: ...
62-
def validate_max_step(max_step: _ToFloatT) -> _ToFloatT: ...
63-
def warn_extraneous(extraneous: dict[str, object]) -> None: ...
57+
def validate_first_step(first_step: _ToFloatT, t0: onp.ToFloat, t_bound: onp.ToFloat) -> _ToFloatT: ... # undocumented
58+
def validate_max_step(max_step: _ToFloatT) -> _ToFloatT: ... # undocumented
59+
def warn_extraneous(extraneous: dict[str, Any]) -> None: ... # undocumented
6460
def validate_tol(
65-
rtol: onp.ArrayND[npc.floating], atol: onp.ArrayND[npc.floating], n: int
66-
) -> tuple[onp.Array1D[npc.floating], onp.Array1D[npc.floating]]: ...
67-
def norm(x: onp.ToFloatND) -> npc.floating: ...
61+
rtol: onp.ArrayND[_FloatingT], atol: onp.ArrayND[_FloatingT], n: int
62+
) -> tuple[onp.Array1D[_FloatingT], onp.Array1D[_FloatingT]]: ... # undocumented
63+
def norm(x: onp.ToFloatND) -> npc.floating: ... # undocumented
6864
def select_initial_step(
6965
fun: Callable[[np.float64, onp.Array1D[np.float64]], onp.Array1D[np.float64]],
7066
t0: float | np.float64,
@@ -76,7 +72,7 @@ def select_initial_step(
7672
order: float | np.float64,
7773
rtol: float | np.float64,
7874
atol: float | np.float64,
79-
) -> float | np.float64: ...
75+
) -> float: ... # undocumented
8076
def num_jac(
8177
fun: Callable[[np.float64, onp.Array1D[np.float64]], onp.Array1D[np.float64]],
8278
t: float | np.float64,
@@ -85,4 +81,4 @@ def num_jac(
8581
threshold: float | np.float64,
8682
factor: onp.ArrayND[np.float64] | None,
8783
sparsity: tuple[csc_matrix, onp.ArrayND[np.intp]] | None = None,
88-
) -> tuple[onp.Array2D[np.float64] | csc_matrix, onp.Array1D[np.float64]]: ...
84+
) -> tuple[onp.Array2D[np.float64] | csc_matrix[np.float64], onp.Array1D[np.float64]]: ... # undocumented

0 commit comments

Comments
 (0)