Skip to content

Commit 1e6c429

Browse files
authored
sparse.linalg: Improved linear solver function annotations (#678)
2 parents 5de0731 + 552d405 commit 1e6c429

File tree

5 files changed

+324
-53
lines changed

5 files changed

+324
-53
lines changed

README.md

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -221,22 +221,23 @@ All generic type parameters are optional and can be omitted if not needed.
221221

222222
#### `scipy.sparse.linalg`
223223

224-
| generic type | |
225-
| --------------------------- | ---------------------------------------------------------------------------------------------------- |
226-
| `LaplacianNd[T: real]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LaplacianNd.html) |
227-
| `LinearOperator[T: scalar]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html) |
224+
| generic type | |
225+
| ----------------------------------- | ---------------------------------------------------------------------------------------------------- |
226+
| `LaplacianNd[T: real]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LaplacianNd.html) |
227+
| `LinearOperator[T: scalar]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html) |
228+
| `SuperLU[T: inexact]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.SuperLU.html) |
228229

229230
### `scipy.stats`
230231

231-
| generic type | |
232-
| ---------------------------------------------- | --------------------------------------------------------------------------------------------- |
233-
| `Covariance[T: real]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Covariance.html) |
234-
| `Uniform[S: (int, ...), T: floating]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Uniform.html) |
235-
| `Normal[S: (int, ...), T: floating]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Normal.html) |
236-
| `Binomial[S: (int, ...), T: floating]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Binomial.html) |
237-
| `Mixture[T: floating]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Mixture.html) |
238-
| `rv_frozen[D: rv_generic, T: scalar or array]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.rv_frozen.html) |
239-
| `multi_rv_frozen[D: rv_generic]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.multi_rv_frozen.html) |
232+
| generic type | |
233+
| ---------------------------------------------- | ---------------------------------------------------------------------------------------- |
234+
| `Covariance[T: real]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Covariance.html) |
235+
| `Uniform[S: (int, ...), T: floating]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Uniform.html) |
236+
| `Normal[S: (int, ...), T: floating]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Normal.html) |
237+
| `Binomial[S: (int, ...), T: floating]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Binomial.html) |
238+
| `Mixture[T: floating]` | [docs](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.Mixture.html) |
239+
| `rv_frozen[D: rv_generic, T: scalar or array]` | |
240+
| `multi_rv_frozen[D: rv_generic]` | |
240241

241242
## Contributing
242243

scipy-stubs/sparse/linalg/_dsolve/_superlu.pyi

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,50 @@
11
from collections.abc import Callable, Mapping
2-
from typing import Any, Literal, TypeAlias, final, overload
2+
from typing import Any, Final, Generic, Literal, TypeAlias, final, overload
3+
from typing_extensions import TypeVar
34

45
import numpy as np
56
import optype as op
67
import optype.numpy as onp
8+
import optype.numpy.compat as npc
79

810
from scipy.sparse import csc_array, csc_matrix, csr_matrix
911

12+
_InexactT_co = TypeVar("_InexactT_co", bound=np.float32 | np.float64 | np.complex64 | np.complex128, default=Any, covariant=True)
13+
1014
_Int1D: TypeAlias = onp.Array1D[np.int32]
1115
_Float1D: TypeAlias = onp.Array1D[np.float64]
1216
_Float2D: TypeAlias = onp.Array2D[np.float64]
1317
_Complex1D: TypeAlias = onp.Array1D[np.complex128]
1418
_Complex2D: TypeAlias = onp.Array2D[np.complex128]
1519
_Inexact2D: TypeAlias = onp.Array2D[np.float32 | np.float64 | np.complex64 | np.complex128]
1620

21+
_Real: TypeAlias = npc.integer | npc.floating
22+
1723
###
1824

1925
@final
20-
class SuperLU:
21-
shape: tuple[int, int]
22-
nnz: int
23-
perm_r: onp.Array1D[np.intp]
24-
perm_c: onp.Array1D[np.intp]
25-
L: csc_array[np.float64 | np.complex128]
26-
U: csc_array[np.float64 | np.complex128]
26+
class SuperLU(Generic[_InexactT_co]):
27+
shape: Final[tuple[int, int]]
28+
nnz: Final[int]
29+
perm_r: Final[onp.Array1D[np.intp]]
30+
perm_c: Final[onp.Array1D[np.intp]]
31+
L: csc_array[_InexactT_co] # readonly
32+
U: csc_array[_InexactT_co] # readonly
2733

2834
@overload
29-
def solve(self, /, rhs: onp.Array1D[np.integer[Any] | np.floating[Any]]) -> _Float1D: ...
35+
def solve(self, /, rhs: onp.Array1D[_Real]) -> _Float1D: ...
3036
@overload
31-
def solve(self, /, rhs: onp.Array1D[np.complexfloating[Any, Any]]) -> _Complex1D: ...
37+
def solve(self, /, rhs: onp.Array1D[npc.complexfloating]) -> _Complex1D: ...
3238
@overload
33-
def solve(self, /, rhs: onp.Array2D[np.integer[Any] | np.floating[Any]]) -> _Float2D: ...
39+
def solve(self, /, rhs: onp.Array2D[_Real]) -> _Float2D: ...
3440
@overload
35-
def solve(self, /, rhs: onp.Array2D[np.complexfloating[Any, Any]]) -> _Complex2D: ...
41+
def solve(self, /, rhs: onp.Array2D[npc.complexfloating]) -> _Complex2D: ...
3642
@overload
37-
def solve(self, /, rhs: onp.ArrayND[np.integer[Any] | np.floating[Any]]) -> _Float1D | _Float2D: ...
43+
def solve(self, /, rhs: onp.ArrayND[_Real]) -> onp.ArrayND[np.float64]: ...
3844
@overload
39-
def solve(self, /, rhs: onp.ArrayND[np.complexfloating[Any, Any]]) -> _Complex1D | _Complex2D: ...
45+
def solve(self, /, rhs: onp.ArrayND[npc.complexfloating]) -> onp.ArrayND[np.complex128]: ...
4046
@overload
41-
def solve(self, /, rhs: onp.ArrayND[np.number[Any]]) -> _Float1D | _Complex1D | _Float2D | _Complex2D: ...
47+
def solve(self, /, rhs: onp.ArrayND[npc.number]) -> onp.ArrayND[np.float64 | np.complex128]: ...
4248

4349
def gssv(
4450
N: op.CanIndex,
@@ -77,4 +83,4 @@ def gstrs(
7783
U_rowind: _Int1D,
7884
U_colptr: _Int1D,
7985
B: _Inexact2D,
80-
) -> tuple[_Float1D | _Complex1D | _Float2D | _Complex2D, int]: ...
86+
) -> tuple[onp.ArrayND[np.float64 | np.complex128], int]: ...

0 commit comments

Comments
 (0)