|
1 | 1 | from collections.abc import Callable, Mapping
|
2 |
| -from typing import Any, Literal, TypeAlias, final, overload |
| 2 | +from typing import Any, Final, Generic, Literal, TypeAlias, final, overload |
| 3 | +from typing_extensions import TypeVar |
3 | 4 |
|
4 | 5 | import numpy as np
|
5 | 6 | import optype as op
|
6 | 7 | import optype.numpy as onp
|
| 8 | +import optype.numpy.compat as npc |
7 | 9 |
|
8 | 10 | from scipy.sparse import csc_array, csc_matrix, csr_matrix
|
9 | 11 |
|
| 12 | +_InexactT_co = TypeVar("_InexactT_co", bound=np.float32 | np.float64 | np.complex64 | np.complex128, default=Any, covariant=True) |
| 13 | + |
10 | 14 | _Int1D: TypeAlias = onp.Array1D[np.int32]
|
11 | 15 | _Float1D: TypeAlias = onp.Array1D[np.float64]
|
12 | 16 | _Float2D: TypeAlias = onp.Array2D[np.float64]
|
13 | 17 | _Complex1D: TypeAlias = onp.Array1D[np.complex128]
|
14 | 18 | _Complex2D: TypeAlias = onp.Array2D[np.complex128]
|
15 | 19 | _Inexact2D: TypeAlias = onp.Array2D[np.float32 | np.float64 | np.complex64 | np.complex128]
|
16 | 20 |
|
| 21 | +_Real: TypeAlias = npc.integer | npc.floating |
| 22 | + |
17 | 23 | ###
|
18 | 24 |
|
19 | 25 | @final
|
20 |
| -class SuperLU: |
21 |
| - shape: tuple[int, int] |
22 |
| - nnz: int |
23 |
| - perm_r: onp.Array1D[np.intp] |
24 |
| - perm_c: onp.Array1D[np.intp] |
25 |
| - L: csc_array[np.float64 | np.complex128] |
26 |
| - U: csc_array[np.float64 | np.complex128] |
| 26 | +class SuperLU(Generic[_InexactT_co]): |
| 27 | + shape: Final[tuple[int, int]] |
| 28 | + nnz: Final[int] |
| 29 | + perm_r: Final[onp.Array1D[np.intp]] |
| 30 | + perm_c: Final[onp.Array1D[np.intp]] |
| 31 | + L: csc_array[_InexactT_co] # readonly |
| 32 | + U: csc_array[_InexactT_co] # readonly |
27 | 33 |
|
28 | 34 | @overload
|
29 |
| - def solve(self, /, rhs: onp.Array1D[np.integer[Any] | np.floating[Any]]) -> _Float1D: ... |
| 35 | + def solve(self, /, rhs: onp.Array1D[_Real]) -> _Float1D: ... |
30 | 36 | @overload
|
31 |
| - def solve(self, /, rhs: onp.Array1D[np.complexfloating[Any, Any]]) -> _Complex1D: ... |
| 37 | + def solve(self, /, rhs: onp.Array1D[npc.complexfloating]) -> _Complex1D: ... |
32 | 38 | @overload
|
33 |
| - def solve(self, /, rhs: onp.Array2D[np.integer[Any] | np.floating[Any]]) -> _Float2D: ... |
| 39 | + def solve(self, /, rhs: onp.Array2D[_Real]) -> _Float2D: ... |
34 | 40 | @overload
|
35 |
| - def solve(self, /, rhs: onp.Array2D[np.complexfloating[Any, Any]]) -> _Complex2D: ... |
| 41 | + def solve(self, /, rhs: onp.Array2D[npc.complexfloating]) -> _Complex2D: ... |
36 | 42 | @overload
|
37 |
| - def solve(self, /, rhs: onp.ArrayND[np.integer[Any] | np.floating[Any]]) -> _Float1D | _Float2D: ... |
| 43 | + def solve(self, /, rhs: onp.ArrayND[_Real]) -> onp.ArrayND[np.float64]: ... |
38 | 44 | @overload
|
39 |
| - def solve(self, /, rhs: onp.ArrayND[np.complexfloating[Any, Any]]) -> _Complex1D | _Complex2D: ... |
| 45 | + def solve(self, /, rhs: onp.ArrayND[npc.complexfloating]) -> onp.ArrayND[np.complex128]: ... |
40 | 46 | @overload
|
41 |
| - def solve(self, /, rhs: onp.ArrayND[np.number[Any]]) -> _Float1D | _Complex1D | _Float2D | _Complex2D: ... |
| 47 | + def solve(self, /, rhs: onp.ArrayND[npc.number]) -> onp.ArrayND[np.float64 | np.complex128]: ... |
42 | 48 |
|
43 | 49 | def gssv(
|
44 | 50 | N: op.CanIndex,
|
@@ -77,4 +83,4 @@ def gstrs(
|
77 | 83 | U_rowind: _Int1D,
|
78 | 84 | U_colptr: _Int1D,
|
79 | 85 | B: _Inexact2D,
|
80 |
| -) -> tuple[_Float1D | _Complex1D | _Float2D | _Complex2D, int]: ... |
| 86 | +) -> tuple[onp.ArrayND[np.float64 | np.complex128], int]: ... |
0 commit comments