Skip to content

Commit a453999

Browse files
authored
🐛 integrate: several BDF fixes (#882)
2 parents 99b15dd + 4d83e55 commit a453999

File tree

2 files changed

+56
-28
lines changed

2 files changed

+56
-28
lines changed

scipy-stubs/integrate/_ivp/base.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import numpy as np
66
import optype.numpy as onp
77
import optype.numpy.compat as npc
88

9-
_ScalarT = TypeVar("_ScalarT", bound=np.float64 | np.complex128, default=np.float64)
9+
_ScalarT = TypeVar("_ScalarT", bound=np.float64 | np.complex128, default=np.float64 | Any)
1010
_ScalarT_co = TypeVar("_ScalarT_co", bound=npc.inexact, default=np.float64 | Any, covariant=True)
1111

1212
_ToFunReal: TypeAlias = Callable[[float, onp.ArrayND[np.float64]], onp.ToFloatND]

scipy-stubs/integrate/_ivp/bdf.pyi

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
11
from collections.abc import Callable
2-
from typing import Final, Generic, Never, TypeAlias
2+
from typing import Any, Final, Generic, Never, TypeAlias, overload
33
from typing_extensions import TypeVar
44

55
import numpy as np
66
import optype.numpy as onp
77
import optype.numpy.compat as npc
88

99
from .base import DenseOutput, OdeSolver
10-
from scipy.sparse import sparray, spmatrix
10+
from scipy.sparse import csc_matrix, sparray, spmatrix
1111

1212
###
1313

14-
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=npc.inexact, default=np.float64 | np.complex128)
14+
_NumberT = TypeVar("_NumberT", bound=npc.number)
15+
_InexactT = TypeVar("_InexactT", bound=np.float64 | np.complex128, default=np.float64 | Any)
1516

1617
_LU: TypeAlias = tuple[onp.ArrayND[npc.inexact], onp.ArrayND[npc.integer]]
1718
_FuncLU: TypeAlias = Callable[[onp.ArrayND[np.float64]], _LU] | Callable[[onp.ArrayND[np.complex128]], _LU]
1819
_FuncSolveLU: TypeAlias = Callable[[_LU, onp.ArrayND], onp.ArrayND[npc.inexact]]
1920

20-
_ToJac: TypeAlias = onp.ToComplex2D | spmatrix | sparray
21+
_Sparse2D: TypeAlias = spmatrix[_NumberT] | sparray[_NumberT, tuple[int, int]]
22+
_ArrayOrCSC: TypeAlias = onp.Array2D[_NumberT] | csc_matrix[_NumberT]
23+
24+
_ToJacReal: TypeAlias = onp.ToFloat2D | _Sparse2D[npc.floating | npc.integer]
25+
_ToJacComplex: TypeAlias = onp.ToComplex2D | _Sparse2D[npc.number]
2126

2227
###
2328

@@ -26,40 +31,63 @@ NEWTON_MAXITER: Final = 4
2631
MIN_FACTOR: Final = 0.2
2732
MAX_FACTOR: Final = 10
2833

29-
class BDF(OdeSolver, Generic[_SCT_co]):
34+
class BDF(OdeSolver[_InexactT], Generic[_InexactT]):
3035
max_step: float
3136
h_abs: float
3237
h_abs_old: float | None
3338
error_norm_old: None
3439
newton_tol: float
35-
jac_factor: onp.ArrayND[np.float64] | None # 1d
3640

37-
LU: _LU
41+
jac_factor: onp.Array1D[np.float64] | None
42+
jac: Callable[[float, onp.ArrayND[_InexactT]], _ArrayOrCSC[_InexactT]] | None
43+
44+
J: _ArrayOrCSC[_InexactT]
45+
I: _ArrayOrCSC[_InexactT]
46+
D: onp.Array2D[_InexactT]
47+
48+
LU: _LU | None
3849
lu: _FuncLU
3950
solve_lu: _FuncSolveLU
4051

41-
I: onp.ArrayND[_SCT_co]
42-
error_const: onp.ArrayND[np.float64]
43-
gamma: onp.ArrayND[np.float64]
44-
alpha: onp.ArrayND[np.float64]
45-
D: onp.ArrayND[np.float64]
46-
order: int
52+
gamma: onp.Array1D[np.float64]
53+
alpha: onp.Array1D[np.float64]
54+
error_const: onp.Array1D[np.float64]
55+
56+
order: int | np.intp
4757
n_equal_steps: int
4858

59+
@overload
60+
def __init__(
61+
self: OdeSolver[np.float64],
62+
/,
63+
fun: Callable[[float, onp.ArrayND[np.float64]], onp.ToFloatND],
64+
t0: float,
65+
y0: onp.ToFloatND,
66+
t_bound: float,
67+
max_step: float = ..., # = np.inf
68+
rtol: float = 1e-3,
69+
atol: float = 1e-6,
70+
jac: _ToJacReal | Callable[[float, onp.ArrayND[np.float64]], _ToJacReal] | None = None,
71+
jac_sparsity: _ToJacReal | None = None,
72+
vectorized: bool = False,
73+
first_step: float | None = None,
74+
**extraneous: Never,
75+
) -> None: ...
76+
@overload
4977
def __init__(
50-
self,
78+
self: OdeSolver[np.complex128],
5179
/,
52-
fun: Callable[[float, onp.Array1D[_SCT_co]], onp.ToComplex1D],
53-
t0: onp.ToFloat,
54-
y0: onp.Array1D[_SCT_co] | onp.ToComplexND,
55-
t_bound: onp.ToFloat,
56-
max_step: onp.ToFloat = ...,
57-
rtol: onp.ToFloat = 0.001,
58-
atol: onp.ToFloat = 1e-06,
59-
jac: _ToJac | Callable[[float, onp.ArrayND[_SCT_co]], _ToJac] | None = None,
60-
jac_sparsity: _ToJac | None = None,
80+
fun: Callable[[float, onp.ArrayND[np.complex128]], onp.ToComplexND],
81+
t0: float,
82+
y0: onp.ToJustComplexND,
83+
t_bound: float,
84+
max_step: float = ..., # = np.inf
85+
rtol: float = 1e-3,
86+
atol: float = 1e-6,
87+
jac: _ToJacComplex | Callable[[float, onp.ArrayND[np.complex128]], _ToJacComplex] | None = None,
88+
jac_sparsity: _ToJacComplex | None = None,
6189
vectorized: bool = False,
62-
first_step: onp.ToFloat | None = None,
90+
first_step: float | None = None,
6391
**extraneous: Never,
6492
) -> None: ...
6593

@@ -73,13 +101,13 @@ class BdfDenseOutput(DenseOutput[np.float64]):
73101
def compute_R(order: int, factor: float) -> onp.ArrayND[np.float64]: ...
74102
def change_D(D: onp.ArrayND[np.float64], order: int, factor: float) -> None: ...
75103
def solve_bdf_system(
76-
fun: Callable[[float, onp.ArrayND[_SCT_co]], onp.ToComplex1D],
104+
fun: Callable[[float, onp.ArrayND[_InexactT]], onp.ToComplex1D],
77105
t_new: onp.ToFloat,
78-
y_predict: onp.ArrayND[_SCT_co],
106+
y_predict: onp.ArrayND[_InexactT],
79107
c: float,
80108
psi: onp.ArrayND[np.float64],
81109
LU: _FuncLU,
82110
solve_lu: _FuncSolveLU,
83111
scale: onp.ArrayND[np.float64],
84112
tol: float,
85-
) -> tuple[bool, int, onp.ArrayND[_SCT_co], onp.ArrayND[_SCT_co]]: ...
113+
) -> tuple[bool, int, onp.ArrayND[_InexactT], onp.ArrayND[_InexactT]]: ...

0 commit comments

Comments
 (0)