1- from collections .abc import Callable , Sequence
2- from typing import Final , Literal , TypeAlias , TypeVar , overload
1+ from collections .abc import Callable
2+ from typing import Any , Final , Generic , Literal , TypeAlias , overload
3+ from typing_extensions import TypeVar
34
45import numpy as np
56import optype as op
@@ -9,13 +10,10 @@ import optype.numpy.compat as npc
910from .base import DenseOutput
1011from scipy .sparse import csc_matrix
1112
12- _SCT = TypeVar ("_SCT " , bound = np . generic )
13+ _FloatingT = TypeVar ("_FloatingT " , bound = npc . floating )
1314_ToFloatT = TypeVar ("_ToFloatT" , bound = onp .ToFloat )
15+ _InterpT_co = TypeVar ("_InterpT_co" , bound = DenseOutput , default = DenseOutput , covariant = True )
1416
15- _Side : TypeAlias = Literal ["left" , "right" ]
16- _Interpolants : TypeAlias = Sequence [DenseOutput ]
17-
18- _To1D : TypeAlias = Sequence [_SCT ] | onp .CanArrayND [_SCT ]
1917_ToFloat64 : TypeAlias = np .float16 | np .float32 | np .float64 | npc .integer | np .bool_
2018
2119###
@@ -28,43 +26,41 @@ NUM_JAC_MIN_FACTOR: Final[float] = ...
2826NUM_JAC_FACTOR_INCREASE : Final [float ] = 10
2927NUM_JAC_FACTOR_DECREASE : Final [float ] = 0.1
3028
31- class OdeSolution :
29+ class OdeSolution (Generic [_InterpT_co ]):
30+ interpolants : list [_InterpT_co ]
3231 ts : onp .Array1D [np .float64 ]
3332 ts_sorted : onp .Array1D [np .float64 ]
3433 t_min : np .float64
3534 t_max : np .float64
3635 ascending : bool
37- side : _Side
36+ side : Literal [ "left" , "right" ]
3837 n_segments : int
39- interpolants : _Interpolants
4038
41- def __init__ (self , / , ts : onp .ToFloat1D , interpolants : _Interpolants , alt_segment : op .CanBool = False ) -> None : ...
39+ def __init__ (self , / , ts : onp .ToFloat1D , interpolants : list [_InterpT_co ], alt_segment : op .CanBool = False ) -> None : ...
40+
41+ #
4242 @overload
4343 def __call__ (self , / , t : float | _ToFloat64 ) -> onp .Array1D [np .float64 ]: ...
4444 @overload
45- def __call__ (self , / , t : np . complex64 | np .complex128 ) -> onp .Array1D [np .complex128 ]: ...
45+ def __call__ (self , / , t : op . JustComplex | np .complex128 | np . complex64 ) -> onp .Array1D [np .complex128 ]: ...
4646 @overload
4747 def __call__ (self , / , t : np .longdouble ) -> onp .Array1D [np .longdouble ]: ...
4848 @overload
4949 def __call__ (self , / , t : np .clongdouble ) -> onp .Array1D [np .clongdouble ]: ...
5050 @overload
51- def __call__ (self , / , t : complex ) -> onp .Array1D [np .float64 | np .complex128 ]: ...
52- @overload
53- def __call__ (self , / , t : Sequence [float | _ToFloat64 ] | onp .CanArrayND [_ToFloat64 ]) -> onp .Array2D [np .float64 ]: ...
54- @overload
55- def __call__ (self , / , t : _To1D [np .complex64 | np .complex128 ]) -> onp .Array2D [np .complex128 ]: ...
51+ def __call__ (self , / , t : onp .ToArray1D [float , _ToFloat64 ]) -> onp .Array2D [np .float64 ]: ...
5652 @overload
57- def __call__ (self , / , t : _To1D [ np .clongdouble ]) -> onp .Array2D [np .clongdouble ]: ...
53+ def __call__ (self , / , t : onp . ToArray1D [ op . JustComplex , np .complex128 | np . complex64 ]) -> onp .Array2D [np .complex128 ]: ...
5854 @overload
59- def __call__ (self , / , t : Sequence [ complex ]) -> onp .Array2D [np .float64 | np . complex128 ]: ...
55+ def __call__ (self , / , t : onp . CanArrayND [ np . clongdouble ]) -> onp .Array2D [np .clongdouble ]: ...
6056
61- def validate_first_step (first_step : _ToFloatT , t0 : onp .ToFloat , t_bound : onp .ToFloat ) -> _ToFloatT : ...
62- def validate_max_step (max_step : _ToFloatT ) -> _ToFloatT : ...
63- def warn_extraneous (extraneous : dict [str , object ]) -> None : ...
57+ def validate_first_step (first_step : _ToFloatT , t0 : onp .ToFloat , t_bound : onp .ToFloat ) -> _ToFloatT : ... # undocumented
58+ def validate_max_step (max_step : _ToFloatT ) -> _ToFloatT : ... # undocumented
59+ def warn_extraneous (extraneous : dict [str , Any ]) -> None : ... # undocumented
6460def validate_tol (
65- rtol : onp .ArrayND [npc . floating ], atol : onp .ArrayND [npc . floating ], n : int
66- ) -> tuple [onp .Array1D [npc . floating ], onp .Array1D [npc . floating ]]: ...
67- def norm (x : onp .ToFloatND ) -> npc .floating : ...
61+ rtol : onp .ArrayND [_FloatingT ], atol : onp .ArrayND [_FloatingT ], n : int
62+ ) -> tuple [onp .Array1D [_FloatingT ], onp .Array1D [_FloatingT ]]: ... # undocumented
63+ def norm (x : onp .ToFloatND ) -> npc .floating : ... # undocumented
6864def select_initial_step (
6965 fun : Callable [[np .float64 , onp .Array1D [np .float64 ]], onp .Array1D [np .float64 ]],
7066 t0 : float | np .float64 ,
@@ -76,7 +72,7 @@ def select_initial_step(
7672 order : float | np .float64 ,
7773 rtol : float | np .float64 ,
7874 atol : float | np .float64 ,
79- ) -> float | np . float64 : ...
75+ ) -> float : ... # undocumented
8076def num_jac (
8177 fun : Callable [[np .float64 , onp .Array1D [np .float64 ]], onp .Array1D [np .float64 ]],
8278 t : float | np .float64 ,
@@ -85,4 +81,4 @@ def num_jac(
8581 threshold : float | np .float64 ,
8682 factor : onp .ArrayND [np .float64 ] | None ,
8783 sparsity : tuple [csc_matrix , onp .ArrayND [np .intp ]] | None = None ,
88- ) -> tuple [onp .Array2D [np .float64 ] | csc_matrix , onp .Array1D [np .float64 ]]: ...
84+ ) -> tuple [onp .Array2D [np .float64 ] | csc_matrix [ np . float64 ] , onp .Array1D [np .float64 ]]: ... # undocumented
0 commit comments