11from collections .abc import Callable
2- from typing import Final , Generic , Never , TypeAlias
2+ from typing import Any , Final , Generic , Never , TypeAlias , overload
33from typing_extensions import TypeVar
44
55import numpy as np
66import optype .numpy as onp
77import optype .numpy .compat as npc
88
99from .base import DenseOutput , OdeSolver
10- from scipy .sparse import sparray , spmatrix
10+ from scipy .sparse import csc_matrix , sparray , spmatrix
1111
1212###
1313
14- _SCT_co = TypeVar ("_SCT_co" , covariant = True , bound = npc .inexact , default = np .float64 | np .complex128 )
14+ _NumberT = TypeVar ("_NumberT" , bound = npc .number )
15+ _InexactT = TypeVar ("_InexactT" , bound = np .float64 | np .complex128 , default = np .float64 | Any )
1516
1617_LU : TypeAlias = tuple [onp .ArrayND [npc .inexact ], onp .ArrayND [npc .integer ]]
1718_FuncLU : TypeAlias = Callable [[onp .ArrayND [np .float64 ]], _LU ] | Callable [[onp .ArrayND [np .complex128 ]], _LU ]
1819_FuncSolveLU : TypeAlias = Callable [[_LU , onp .ArrayND ], onp .ArrayND [npc .inexact ]]
1920
20- _ToJac : TypeAlias = onp .ToComplex2D | spmatrix | sparray
21+ _Sparse2D : TypeAlias = spmatrix [_NumberT ] | sparray [_NumberT , tuple [int , int ]]
22+ _ArrayOrCSC : TypeAlias = onp .Array2D [_NumberT ] | csc_matrix [_NumberT ]
23+
24+ _ToJacReal : TypeAlias = onp .ToFloat2D | _Sparse2D [npc .floating | npc .integer ]
25+ _ToJacComplex : TypeAlias = onp .ToComplex2D | _Sparse2D [npc .number ]
2126
2227###
2328
@@ -26,40 +31,63 @@ NEWTON_MAXITER: Final = 4
2631MIN_FACTOR : Final = 0.2
2732MAX_FACTOR : Final = 10
2833
29- class BDF (OdeSolver , Generic [_SCT_co ]):
34+ class BDF (OdeSolver [ _InexactT ] , Generic [_InexactT ]):
3035 max_step : float
3136 h_abs : float
3237 h_abs_old : float | None
3338 error_norm_old : None
3439 newton_tol : float
35- jac_factor : onp .ArrayND [np .float64 ] | None # 1d
3640
37- LU : _LU
41+ jac_factor : onp .Array1D [np .float64 ] | None
42+ jac : Callable [[float , onp .ArrayND [_InexactT ]], _ArrayOrCSC [_InexactT ]] | None
43+
44+ J : _ArrayOrCSC [_InexactT ]
45+ I : _ArrayOrCSC [_InexactT ]
46+ D : onp .Array2D [_InexactT ]
47+
48+ LU : _LU | None
3849 lu : _FuncLU
3950 solve_lu : _FuncSolveLU
4051
41- I : onp .ArrayND [_SCT_co ]
42- error_const : onp .ArrayND [np .float64 ]
43- gamma : onp .ArrayND [np .float64 ]
44- alpha : onp .ArrayND [np .float64 ]
45- D : onp .ArrayND [np .float64 ]
46- order : int
52+ gamma : onp .Array1D [np .float64 ]
53+ alpha : onp .Array1D [np .float64 ]
54+ error_const : onp .Array1D [np .float64 ]
55+
56+ order : int | np .intp
4757 n_equal_steps : int
4858
59+ @overload
60+ def __init__ (
61+ self : OdeSolver [np .float64 ],
62+ / ,
63+ fun : Callable [[float , onp .ArrayND [np .float64 ]], onp .ToFloatND ],
64+ t0 : float ,
65+ y0 : onp .ToFloatND ,
66+ t_bound : float ,
67+ max_step : float = ..., # = np.inf
68+ rtol : float = 1e-3 ,
69+ atol : float = 1e-6 ,
70+ jac : _ToJacReal | Callable [[float , onp .ArrayND [np .float64 ]], _ToJacReal ] | None = None ,
71+ jac_sparsity : _ToJacReal | None = None ,
72+ vectorized : bool = False ,
73+ first_step : float | None = None ,
74+ ** extraneous : Never ,
75+ ) -> None : ...
76+ @overload
4977 def __init__ (
50- self ,
78+ self : OdeSolver [ np . complex128 ] ,
5179 / ,
52- fun : Callable [[float , onp .Array1D [ _SCT_co ]], onp .ToComplex1D ],
53- t0 : onp . ToFloat ,
54- y0 : onp .Array1D [ _SCT_co ] | onp . ToComplexND ,
55- t_bound : onp . ToFloat ,
56- max_step : onp . ToFloat = ...,
57- rtol : onp . ToFloat = 0.001 ,
58- atol : onp . ToFloat = 1e-06 ,
59- jac : _ToJac | Callable [[float , onp .ArrayND [_SCT_co ]], _ToJac ] | None = None ,
60- jac_sparsity : _ToJac | None = None ,
80+ fun : Callable [[float , onp .ArrayND [ np . complex128 ]], onp .ToComplexND ],
81+ t0 : float ,
82+ y0 : onp .ToJustComplexND ,
83+ t_bound : float ,
84+ max_step : float = ..., # = np.inf
85+ rtol : float = 1e-3 ,
86+ atol : float = 1e-6 ,
87+ jac : _ToJacComplex | Callable [[float , onp .ArrayND [np . complex128 ]], _ToJacComplex ] | None = None ,
88+ jac_sparsity : _ToJacComplex | None = None ,
6189 vectorized : bool = False ,
62- first_step : onp . ToFloat | None = None ,
90+ first_step : float | None = None ,
6391 ** extraneous : Never ,
6492 ) -> None : ...
6593
@@ -73,13 +101,13 @@ class BdfDenseOutput(DenseOutput[np.float64]):
73101def compute_R (order : int , factor : float ) -> onp .ArrayND [np .float64 ]: ...
74102def change_D (D : onp .ArrayND [np .float64 ], order : int , factor : float ) -> None : ...
75103def solve_bdf_system (
76- fun : Callable [[float , onp .ArrayND [_SCT_co ]], onp .ToComplex1D ],
104+ fun : Callable [[float , onp .ArrayND [_InexactT ]], onp .ToComplex1D ],
77105 t_new : onp .ToFloat ,
78- y_predict : onp .ArrayND [_SCT_co ],
106+ y_predict : onp .ArrayND [_InexactT ],
79107 c : float ,
80108 psi : onp .ArrayND [np .float64 ],
81109 LU : _FuncLU ,
82110 solve_lu : _FuncSolveLU ,
83111 scale : onp .ArrayND [np .float64 ],
84112 tol : float ,
85- ) -> tuple [bool , int , onp .ArrayND [_SCT_co ], onp .ArrayND [_SCT_co ]]: ...
113+ ) -> tuple [bool , int , onp .ArrayND [_InexactT ], onp .ArrayND [_InexactT ]]: ...
0 commit comments