11from collections .abc import Callable , Sequence
2- from typing import Concatenate , Final , Generic , Literal , TypeAlias , overload , type_check_only
2+ from typing import Any , Final , Generic , Literal , TypeAlias , TypeVarTuple , overload , type_check_only
33from typing_extensions import TypeVar , TypedDict , Unpack
44
55import numpy as np
@@ -10,58 +10,63 @@ from .base import DenseOutput, OdeSolver
1010from .common import OdeSolution
1111from scipy ._lib ._util import _RichResult
1212from scipy .sparse import sparray , spmatrix
13+ from scipy .sparse ._base import _spbase
1314
14- _SCT_cf = TypeVar ("_SCT_cf" , bound = npc .inexact , default = np .float64 | np .complex128 )
15+ _Ts = TypeVarTuple ("_Ts" )
16+ _ScalarT = TypeVar ("_ScalarT" , bound = npc .number | np .bool )
17+ _Inexact64T = TypeVar ("_Inexact64T" , bound = np .float64 | np .complex128 )
18+ _Inexact64T_co = TypeVar ("_Inexact64T_co" , bound = np .float64 | np .complex128 , default = np .float64 | np .complex128 , covariant = True )
1519
16- _FuncSol : TypeAlias = Callable [[float ], onp .ArrayND [_SCT_cf ]]
17- _FuncEvent : TypeAlias = Callable [[float , onp .ArrayND [_SCT_cf ] ], float ]
18- _Events : TypeAlias = Sequence [_FuncEvent [_SCT_cf ] ]
20+ _FuncSol : TypeAlias = Callable [[np . float64 ], onp .ArrayND [_Inexact64T ]]
21+ _FuncEvent : TypeAlias = Callable [[np . float64 , onp .ArrayND [_Inexact64T ], * _Ts ], float ]
22+ _Events : TypeAlias = Sequence [_FuncEvent [_Inexact64T , * _Ts ]] | _FuncEvent [ _Inexact64T , * _Ts ]
1923
20- _Int1D : TypeAlias = onp .Array1D [np .intp ]
24+ _Int1D : TypeAlias = onp .Array1D [np .int_ ]
2125_Float1D : TypeAlias = onp .Array1D [np .float64 ]
26+ _Float2D : TypeAlias = onp .Array2D [np .float64 ]
27+ _Complex1D : TypeAlias = onp .Array1D [np .complex128 ]
28+ _Complex2D : TypeAlias = onp .Array2D [np .complex128 ]
2229
23- _ToJac : TypeAlias = onp .ToComplex2D | spmatrix | sparray
30+ _Sparse2D : TypeAlias = _spbase [_ScalarT , tuple [int , int ]] | sparray [_ScalarT , tuple [int , int ]] | spmatrix [_ScalarT ]
31+ _ToJac : TypeAlias = onp .ToArray2D [complex , npc .inexact ] | _Sparse2D [npc .inexact ]
2432
25- _IVPMethod : TypeAlias = Literal ["RK23" , "RK45" , "DOP853" , "Radau" , "BDF" , "LSODA" ]
33+ _IVPMethod : TypeAlias = Literal ["RK23" , "RK45" , "DOP853" , "Radau" , "BDF" , "LSODA" ] | type [ OdeSolver ]
2634
2735@type_check_only
28- class _SolverOptions (TypedDict , Generic [ _SCT_cf ], total = False ):
29- first_step : onp . ToFloat | None
30- max_step : onp . ToFloat
31- rtol : onp . ToFloat | onp .ToFloat1D
32- atol : onp . ToFloat | onp .ToFloat1D
33- jac : _ToJac | Callable [[float , onp .Array1D [ np . float64 ] ], _ToJac ] | None
34- jac_sparsity : onp .ToFloat2D | spmatrix | sparray | None
35- lband : onp . ToInt | None
36- uband : onp . ToInt | None
37- min_step : onp . ToFloat
36+ class _SolverOptions (TypedDict , total = False ):
37+ first_step : float | None
38+ max_step : float
39+ rtol : float | onp .ToFloat1D
40+ atol : float | onp .ToFloat1D
41+ jac : _ToJac | Callable [[np . float64 , onp .Array1D ], _ToJac ] | None
42+ jac_sparsity : onp .ToFloat2D | _Sparse2D [ npc . floating ] | None
43+ lband : int | None
44+ uband : int | None
45+ min_step : float
3846
3947###
4048
4149METHODS : Final [dict [str , type ]] = ...
4250MESSAGES : Final [dict [int , str ]] = ...
4351
44- class OdeResult (
45- _RichResult [int | str | onp .ArrayND [np .float64 | _SCT_cf ] | list [onp .ArrayND [np .float64 | _SCT_cf ]] | OdeSolution | None ],
46- Generic [_SCT_cf ],
47- ):
52+ class OdeResult (_RichResult [Any ], Generic [_Inexact64T_co ]):
4853 t : _Float1D
49- y : onp .Array2D [_SCT_cf ]
54+ y : onp .Array2D [_Inexact64T_co ]
5055 sol : OdeSolution | None
5156 t_events : list [_Float1D ] | None
52- y_events : list [onp .ArrayND [_SCT_cf ]] | None
57+ y_events : list [onp .ArrayND [_Inexact64T_co ]] | None
5358 nfev : int
5459 njev : int
5560 nlu : int
5661 status : Literal [- 1 , 0 , 1 ]
5762 message : str
5863 success : bool
5964
60- def prepare_events (events : _FuncEvent [ _SCT_cf ] | _Events [_SCT_cf ]) -> tuple [_Events [_SCT_cf ], _Float1D , _Float1D ]: ...
61- def solve_event_equation (event : _FuncEvent [_SCT_cf ], sol : _FuncSol [_SCT_cf ], t_old : float , t : float ) -> float : ...
65+ def prepare_events (events : _Events [_Inexact64T ]) -> tuple [_Events [_Inexact64T ], _Float1D , _Float1D ]: ...
66+ def solve_event_equation (event : _FuncEvent [_Inexact64T ], sol : _FuncSol [_Inexact64T ], t_old : float , t : float ) -> float : ...
6267def handle_events (
6368 sol : DenseOutput ,
64- events : Sequence [_FuncEvent [_SCT_cf ]],
69+ events : Sequence [_FuncEvent [_Inexact64T ]],
6570 active_events : onp .ArrayND [np .intp ],
6671 event_count : onp .ArrayND [np .intp | np .float64 ],
6772 max_events : onp .ArrayND [np .intp | np .float64 ],
@@ -71,30 +76,113 @@ def handle_events(
7176def find_active_events (g : onp .ToFloat1D , g_new : onp .ToFloat1D , direction : onp .ArrayND [np .float64 ]) -> _Int1D : ...
7277
7378#
74- @overload
79+ @overload # float, vectorized=False (default), args=None (default)
7580def solve_ivp (
76- fun : Callable [Concatenate [ float , onp . Array1D [ _SCT_cf ], ... ], onp .ArrayND [ _SCT_cf ] ],
77- t_span : Sequence [onp . ToFloat ],
78- y0 : onp .ToArray1D ,
79- method : _IVPMethod | type [ OdeSolver ] = "RK45" ,
81+ fun : Callable [[ np . float64 , _Float1D ], onp .ToFloat1D | float ],
82+ t_span : Sequence [float ],
83+ y0 : onp .ToFloat1D ,
84+ method : _IVPMethod = "RK45" ,
8085 t_eval : onp .ToFloat1D | None = None ,
8186 dense_output : bool = False ,
82- events : _Events [_SCT_cf ] | None = None ,
87+ events : _Events [np . float64 ] | None = None ,
8388 vectorized : onp .ToFalse = False ,
84- args : tuple [ object , ...] | None = None ,
89+ args : None = None ,
8590 ** options : Unpack [_SolverOptions ],
86- ) -> OdeResult [_SCT_cf ]: ...
87- @overload
91+ ) -> OdeResult [np . float64 ]: ...
92+ @overload # float, vectorized=False (default), args=<given>
8893def solve_ivp (
89- fun : Callable [Concatenate [ _Float1D , onp . Array2D [ _SCT_cf ], ... ], onp .ArrayND [ _SCT_cf ] ],
90- t_span : Sequence [onp . ToFloat ],
91- y0 : onp .ToArray1D ,
92- method : _IVPMethod | type [ OdeSolver ] = "RK45" ,
94+ fun : Callable [[ np . float64 , _Float1D , * _Ts ], onp .ToFloat1D | float ],
95+ t_span : Sequence [float ],
96+ y0 : onp .ToFloat1D ,
97+ method : _IVPMethod = "RK45" ,
9398 t_eval : onp .ToFloat1D | None = None ,
9499 dense_output : bool = False ,
95- events : _Events [_SCT_cf ] | None = None ,
100+ events : _Events [np .float64 ] | None = None ,
101+ vectorized : onp .ToFalse = False ,
102+ * ,
103+ args : tuple [* _Ts ],
104+ ** options : Unpack [_SolverOptions ],
105+ ) -> OdeResult [np .float64 ]: ...
106+ @overload # float, vectorized=True, args=None (default)
107+ def solve_ivp (
108+ fun : Callable [[_Float1D , _Float2D ], onp .ToFloat2D ],
109+ t_span : Sequence [float ],
110+ y0 : onp .ToFloat1D ,
111+ method : _IVPMethod = "RK45" ,
112+ t_eval : onp .ToFloat1D | None = None ,
113+ dense_output : bool = False ,
114+ events : _Events [np .float64 ] | None = None ,
115+ * ,
116+ vectorized : onp .ToTrue ,
117+ args : None = None ,
118+ ** options : Unpack [_SolverOptions ],
119+ ) -> OdeResult [np .float64 ]: ...
120+ @overload # float, vectorized=True, args=<given>
121+ def solve_ivp (
122+ fun : Callable [[_Float1D , _Float2D , * _Ts ], onp .ToFloat2D ],
123+ t_span : Sequence [float ],
124+ y0 : onp .ToFloat1D ,
125+ method : _IVPMethod = "RK45" ,
126+ t_eval : onp .ToFloat1D | None = None ,
127+ dense_output : bool = False ,
128+ events : _Events [np .float64 ] | None = None ,
129+ * ,
130+ vectorized : onp .ToTrue ,
131+ args : tuple [* _Ts ],
132+ ** options : Unpack [_SolverOptions ],
133+ ) -> OdeResult [np .float64 ]: ...
134+ @overload # complex, vectorized=False (default), args=None (default)
135+ def solve_ivp (
136+ fun : Callable [[np .float64 , _Complex1D ], onp .ToComplex1D | complex ],
137+ t_span : Sequence [float ],
138+ y0 : onp .ToComplex1D ,
139+ method : _IVPMethod = "RK45" ,
140+ t_eval : onp .ToFloat1D | None = None ,
141+ dense_output : bool = False ,
142+ events : _Events [np .complex128 ] | None = None ,
143+ vectorized : onp .ToFalse = False ,
144+ args : None = None ,
145+ ** options : Unpack [_SolverOptions ],
146+ ) -> OdeResult [np .complex128 ]: ...
147+ @overload # complex, vectorized=False (default), args=<given>
148+ def solve_ivp (
149+ fun : Callable [[np .float64 , _Complex1D , * _Ts ], onp .ToComplex1D | complex ],
150+ t_span : Sequence [float ],
151+ y0 : onp .ToComplex1D ,
152+ method : _IVPMethod = "RK45" ,
153+ t_eval : onp .ToFloat1D | None = None ,
154+ dense_output : bool = False ,
155+ events : _Events [np .complex128 ] | None = None ,
156+ vectorized : onp .ToFalse = False ,
157+ * ,
158+ args : tuple [* _Ts ],
159+ ** options : Unpack [_SolverOptions ],
160+ ) -> OdeResult [np .complex128 ]: ...
161+ @overload # complex, vectorized=True, args=None (default)
162+ def solve_ivp (
163+ fun : Callable [[_Float1D , _Complex2D ], onp .ToComplex2D ],
164+ t_span : Sequence [float ],
165+ y0 : onp .ToComplex1D ,
166+ method : _IVPMethod = "RK45" ,
167+ t_eval : onp .ToFloat1D | None = None ,
168+ dense_output : bool = False ,
169+ events : _Events [np .complex128 ] | None = None ,
170+ * ,
171+ vectorized : onp .ToTrue ,
172+ args : None = None ,
173+ ** options : Unpack [_SolverOptions ],
174+ ) -> OdeResult [np .complex128 ]: ...
175+ @overload # complex, vectorized=True, args=<given>
176+ def solve_ivp (
177+ fun : Callable [[_Float1D , _Complex2D , * _Ts ], onp .ToComplex2D ],
178+ t_span : Sequence [float ],
179+ y0 : onp .ToComplex1D ,
180+ method : _IVPMethod = "RK45" ,
181+ t_eval : onp .ToFloat1D | None = None ,
182+ dense_output : bool = False ,
183+ events : _Events [np .complex128 ] | None = None ,
96184 * ,
97185 vectorized : onp .ToTrue ,
98- args : tuple [object , ...] | None = None ,
186+ args : tuple [* _Ts ] ,
99187 ** options : Unpack [_SolverOptions ],
100- ) -> OdeResult [_SCT_cf ]: ...
188+ ) -> OdeResult [np . complex128 ]: ...
0 commit comments