11from collections .abc import Callable
2- from typing import Any , ClassVar , Final , Generic , Literal , Protocol , Self , TypeAlias , TypedDict , type_check_only
2+ from typing import Any , ClassVar , Final , Generic , Literal , Self , TypeAlias , TypedDict , overload , type_check_only
33from typing_extensions import TypeVar , TypeVarTuple , Unpack , override
44
55import numpy as np
@@ -8,8 +8,11 @@ import optype.numpy.compat as npc
88
99__all__ = ["complex_ode" , "ode" ]
1010
11- _SCT_co = TypeVar ("_SCT_co" , covariant = True , bound = npc .inexact , default = np .float64 | np .complex128 )
12- _Ts = TypeVarTuple ("_Ts" , default = Unpack [tuple [()]])
11+ _Ts = TypeVarTuple ("_Ts" , default = Unpack [tuple [Any , ...]])
12+ _Inexact64T_co = TypeVar ("_Inexact64T_co" , bound = npc .inexact , default = np .float64 | np .complex128 , covariant = True )
13+
14+ _IntegratorReal : TypeAlias = Literal ["vode" , "dopri5" , "dop853" , "lsoda" ]
15+ _IntegratorComplex : TypeAlias = Literal ["vode" , "zvode" ]
1316
1417@type_check_only
1518class _IntegratorParams (TypedDict , total = False ):
@@ -34,56 +37,69 @@ class _IntegratorParams(TypedDict, total=False):
3437 beta : float
3538 verbosity : int
3639
37- @type_check_only
38- class _ODEFuncF (Protocol [* _Ts ]):
39- def __call__ (self , t : float , y : float | onp .ArrayND [np .float64 ], / , * args : * _Ts ) -> float | onp .ArrayND [npc .floating ]: ...
40+ ###
4041
41- @type_check_only
42- class _ODEFuncC (Protocol [* _Ts ]):
43- def __call__ (
44- self , t : float , y : complex | onp .ArrayND [np .complex128 ], / , * args : * _Ts
45- ) -> complex | onp .ArrayND [npc .complexfloating ]: ...
42+ class IntegratorConcurrencyError (RuntimeError ):
43+ def __init__ (self , / , name : str ) -> None : ...
4644
47- _SolOutFunc : TypeAlias = Callable [[float , onp .Array1D [npc .inexact ]], Literal [0 , - 1 ]]
45+ class ode (Generic [_Inexact64T_co , * _Ts ]):
46+ f : Callable [[float , onp .Array1D [_Inexact64T_co ], * _Ts ], complex | onp .ToComplex1D ]
47+ jac : Callable [[float , onp .Array1D [_Inexact64T_co ], * _Ts ], complex | onp .ToComplex2D ] | None
48+ f_params : tuple [* _Ts ]
49+ jac_params : tuple [* _Ts ]
50+ stiff : Literal [0 , 1 ]
51+ t : float
4852
49- ###
53+ def __init__ (
54+ self ,
55+ / ,
56+ f : Callable [[float , onp .Array1D [_Inexact64T_co ], * _Ts ], complex | onp .ToComplex1D ],
57+ jac : Callable [[float , onp .Array1D [_Inexact64T_co ], * _Ts ], complex | onp .ToComplex2D ] | None = None ,
58+ ) -> None : ...
5059
51- class ode (Generic [* _Ts ]):
52- stiff : int
53- f : _ODEFuncF [* _Ts ]
54- f_params : tuple [()] | tuple [* _Ts ]
55- jac : _ODEFuncF [* _Ts ] | None
56- jac_params : tuple [()] | tuple [* _Ts ]
57- t : float
58- def __init__ (self , / , f : _ODEFuncF [* _Ts ], jac : _ODEFuncF [* _Ts ] | None = None ) -> None : ...
60+ #
5961 @property
60- def y (self , / ) -> float : ...
61- def integrate (self , / , t : float , step : bool = False , relax : bool = False ) -> float : ...
62- def set_initial_value (self , / , y : onp .ToComplex | onp .ToComplexND , t : float = 0.0 ) -> Self : ...
63- def set_integrator (self , / , name : str , ** integrator_params : Unpack [_IntegratorParams ]) -> Self : ...
62+ def y (self , / ) -> onp .Array1D [_Inexact64T_co ]: ...
63+
64+ #
65+ @overload
66+ def set_initial_value (
67+ self : ode [np .float64 , * _Ts ], / , y : float | onp .ToFloat1D , t : float = 0.0
68+ ) -> ode [_Inexact64T_co , * _Ts ]: ...
69+ @overload
70+ def set_initial_value (
71+ self : ode [np .complex128 , * _Ts ], / , y : complex | onp .ToComplex1D , t : float = 0.0
72+ ) -> ode [_Inexact64T_co , * _Ts ]: ...
73+
74+ #
75+ @overload
76+ def set_integrator (
77+ self : ode [np .float64 , * _Ts ], / , name : _IntegratorReal , ** integrator_params : Unpack [_IntegratorParams ]
78+ ) -> ode [_Inexact64T_co , * _Ts ]: ...
79+ @overload
80+ def set_integrator (
81+ self : ode [np .complex128 , * _Ts ], / , name : _IntegratorComplex , ** integrator_params : Unpack [_IntegratorParams ]
82+ ) -> ode [_Inexact64T_co , * _Ts ]: ...
83+
84+ #
85+ def integrate (self , / , t : float , step : bool = False , relax : bool = False ) -> onp .Array1D [_Inexact64T_co ]: ...
86+ def successful (self , / ) -> bool : ...
87+ def get_return_code (self , / ) -> Literal [- 7 , - 6 , - 5 , - 4 , - 3 , - 2 , - 1 , 1 , 2 ]: ...
6488 def set_f_params (self , / , * args : * _Ts ) -> Self : ...
6589 def set_jac_params (self , / , * args : * _Ts ) -> Self : ...
66- def set_solout (self , / , solout : _SolOutFunc ) -> None : ...
67- def get_return_code (self , / ) -> Literal [- 7 , - 6 , - 5 , - 4 , - 3 , - 2 , - 1 , 1 , 2 ]: ...
68- def successful (self , / ) -> bool : ...
90+ def set_solout (self , / , solout : Callable [[float , onp .Array1D [_Inexact64T_co ]], Literal [- 1 , 0 ] | None ]) -> None : ...
6991
70- class complex_ode (ode [* _Ts ], Generic [* _Ts ]):
71- cf : _ODEFuncC [ * _Ts ]
72- cjac : _ODEFuncC [ * _Ts ] | None
92+ class complex_ode (ode [np . complex128 , * _Ts ], Generic [* _Ts ]):
93+ cf : Callable [[ float , onp . Array1D [ np . complex128 ], * _Ts ], complex | onp . ToComplex1D ]
94+ cjac : Callable [[ float , onp . Array1D [ np . complex128 ], * _Ts ], complex | onp . ToComplex2D ] | None
7395 tmp : onp .Array1D [np .float64 ]
74- def __init__ (self , / , f : _ODEFuncC [* _Ts ], jac : _ODEFuncC [* _Ts ] | None = None ) -> None : ...
75- @property
96+
7697 @override
77- def y (self , / ) -> complex : ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
98+ def set_integrator (self , / , name : _IntegratorReal , ** integrator_params : Unpack [ _IntegratorParams ] ) -> Self : ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
7899 @override
79- def integrate (self , / , t : float , step : bool = False , relax : bool = False ) -> complex : ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
100+ def set_initial_value (self , / , y : complex | onp . ToComplex1D , t : float = 0.0 ) -> Self : ...
80101
81- def find_integrator (name : str ) -> type [IntegratorBase ] | None : ...
82-
83- class IntegratorConcurrencyError (RuntimeError ):
84- def __init__ (self , / , name : str ) -> None : ...
85-
86- class IntegratorBase (Generic [_SCT_co ]):
102+ class IntegratorBase (Generic [_Inexact64T_co ]):
87103 runner : ClassVar [Callable [..., tuple [Any , ...]] | None ] # fortran function or unavailable
88104 supports_run_relax : ClassVar [Literal [0 , 1 ] | None ] = None
89105 supports_step : ClassVar [Literal [0 , 1 ] | None ] = None
@@ -101,38 +117,38 @@ class IntegratorBase(Generic[_SCT_co]):
101117 def run (
102118 self ,
103119 / ,
104- f : Callable [..., _SCT_co ],
105- jac : Callable [..., onp .ArrayND [_SCT_co ]] | None ,
120+ f : Callable [..., _Inexact64T_co ],
121+ jac : Callable [..., onp .ArrayND [_Inexact64T_co ]] | None ,
106122 y0 : complex ,
107123 t0 : float ,
108124 t1 : float ,
109125 f_params : tuple [object , ...],
110126 jac_params : tuple [object , ...],
111- ) -> tuple [_SCT_co , float ]: ...
127+ ) -> tuple [_Inexact64T_co , float ]: ...
112128 def step (
113129 self ,
114130 / ,
115- f : Callable [..., _SCT_co ],
116- jac : Callable [..., onp .ArrayND [_SCT_co ]],
131+ f : Callable [..., _Inexact64T_co ],
132+ jac : Callable [..., onp .ArrayND [_Inexact64T_co ]],
117133 y0 : complex ,
118134 t0 : float ,
119135 t1 : float ,
120136 f_params : tuple [object , ...],
121137 jac_params : tuple [object , ...],
122- ) -> tuple [_SCT_co , float ]: ...
138+ ) -> tuple [_Inexact64T_co , float ]: ...
123139 def run_relax (
124140 self ,
125141 / ,
126- f : Callable [..., _SCT_co ],
127- jac : Callable [..., onp .ArrayND [_SCT_co ]],
142+ f : Callable [..., _Inexact64T_co ],
143+ jac : Callable [..., onp .ArrayND [_Inexact64T_co ]],
128144 y0 : complex ,
129145 t0 : float ,
130146 t1 : float ,
131147 f_params : tuple [object , ...],
132148 jac_params : tuple [object , ...],
133- ) -> tuple [_SCT_co , float ]: ...
149+ ) -> tuple [_Inexact64T_co , float ]: ...
134150
135- class vode (IntegratorBase [_SCT_co ], Generic [_SCT_co ]):
151+ class vode (IntegratorBase [_Inexact64T_co ], Generic [_Inexact64T_co ]):
136152 messages : ClassVar [dict [int , str ]] = ...
137153
138154 active_global_handle : int
@@ -210,14 +226,16 @@ class dopri5(IntegratorBase[np.float64]):
210226 method : None = None , # unused
211227 verbosity : int = - 1 ,
212228 ) -> None : ...
213- def set_solout (self , / , solout : _SolOutFunc | None , complex : bool = False ) -> None : ...
229+ def set_solout (
230+ self , / , solout : Callable [[float , onp .Array1D [np .float64 ]], Literal [0 , - 1 ]] | None , complex : bool = False
231+ ) -> None : ...
214232 def _solout (
215233 self ,
216234 / ,
217235 nr : int , # unused
218236 xold : object , # unused
219237 x : float ,
220- y : onp .Array1D [npc . floating ],
238+ y : onp .Array1D [np . float64 ],
221239 nd : int , # unused
222240 icomp : int , # unused
223241 con : object , # unused
@@ -280,3 +298,16 @@ class lsoda(IntegratorBase[np.float64]):
280298 max_order_s : int = 5 ,
281299 method : None = None , # ignored
282300 ) -> None : ...
301+
302+ @overload
303+ def find_integrator (name : Literal ["vode" ]) -> type [vode ]: ...
304+ @overload
305+ def find_integrator (name : Literal ["zvode" ]) -> type [zvode ]: ...
306+ @overload
307+ def find_integrator (name : Literal ["dopri5" ]) -> type [dopri5 ]: ...
308+ @overload
309+ def find_integrator (name : Literal ["dop853" ]) -> type [dop853 ]: ...
310+ @overload
311+ def find_integrator (name : Literal ["lsoda" ]) -> type [lsoda ]: ...
312+ @overload
313+ def find_integrator (name : str ) -> type [IntegratorBase ] | None : ...
0 commit comments