Skip to content

Commit 62d93ba

Browse files
authored
🐛sparse.linalg: allow passing arrays to expm_multiply (#708)
2 parents 5f252d9 + 22c451a commit 62d93ba

File tree

4 files changed

+163
-13
lines changed

4 files changed

+163
-13
lines changed
Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,115 @@
1-
from typing import Any, TypeVar, overload
1+
from _typeshed import Incomplete
2+
from typing import Any, Never, TypeAlias, TypeVar, overload
23

34
import numpy as np
45
import optype as op
56
import optype.numpy as onp
7+
import optype.numpy.compat as npc
68

7-
from scipy.sparse._typing import Numeric
8-
from scipy.sparse.linalg._interface import LinearOperator
9+
from ._interface import LinearOperator
10+
from scipy.sparse._base import _spbase, sparray
911

1012
__all__ = ["expm_multiply"]
1113

12-
_SCT = TypeVar("_SCT", bound=Numeric)
14+
_ScalarT = TypeVar("_ScalarT", bound=npc.number | np.bool_)
15+
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
16+
_ShapeT = TypeVar("_ShapeT", bound=tuple[Any, ...])
17+
18+
_ToLinearOperator: TypeAlias = LinearOperator[_ScalarT] | _spbase[_ScalarT, tuple[int, int]] | onp.ArrayND[_ScalarT]
19+
_SparseOrDense: TypeAlias = sparray[_ScalarT, _ShapeT] | onp.ArrayND[_ScalarT, _ShapeT]
20+
21+
_AsFloat64: TypeAlias = np.float64 | npc.integer | np.bool_
22+
_ToFloat64: TypeAlias = _AsFloat64 | np.float32 | np.float16
1323

1424
###
1525

26+
@overload # workaround for mypy's and pyright's typing spec non-compliance regarding overloads
27+
def expm_multiply(
28+
A: _ToLinearOperator[_AsFloat64],
29+
B: _SparseOrDense[_ToFloat64, tuple[Never] | tuple[Never, Never]],
30+
start: onp.ToFloat | None = None,
31+
stop: onp.ToFloat | None = None,
32+
num: op.CanIndex | None = None,
33+
endpoint: bool | None = None,
34+
traceA: onp.ToComplex | None = None,
35+
) -> onp.ArrayND[np.float64]: ...
36+
@overload
37+
def expm_multiply(
38+
A: _ToLinearOperator[_InexactT],
39+
B: _SparseOrDense[_InexactT | npc.integer | np.bool_, tuple[Never] | tuple[Never, Never]],
40+
start: onp.ToFloat | None = None,
41+
stop: onp.ToFloat | None = None,
42+
num: op.CanIndex | None = None,
43+
endpoint: bool | None = None,
44+
traceA: onp.ToComplex | None = None,
45+
) -> onp.ArrayND[_InexactT]: ...
1646
@overload # 1-d
1747
def expm_multiply(
18-
A: LinearOperator[_SCT],
19-
B: onp.Array1D[_SCT | np.integer[Any] | np.float16 | np.bool_],
48+
A: _ToLinearOperator[_AsFloat64],
49+
B: _SparseOrDense[_ToFloat64, tuple[int]],
50+
start: onp.ToFloat | None = None,
51+
stop: onp.ToFloat | None = None,
52+
num: op.CanIndex | None = None,
53+
endpoint: bool | None = None,
54+
traceA: onp.ToComplex | None = None,
55+
) -> onp.Array1D[np.float64]: ...
56+
@overload
57+
def expm_multiply(
58+
A: _ToLinearOperator[_InexactT],
59+
B: _SparseOrDense[_InexactT | npc.integer | np.bool_, tuple[int]],
2060
start: onp.ToFloat | None = None,
2161
stop: onp.ToFloat | None = None,
2262
num: op.CanIndex | None = None,
2363
endpoint: bool | None = None,
2464
traceA: onp.ToComplex | None = None,
25-
) -> onp.Array1D[_SCT]: ...
65+
) -> onp.Array1D[_InexactT]: ...
2666
@overload # 2-d
2767
def expm_multiply(
28-
A: LinearOperator[_SCT],
29-
B: onp.Array2D[_SCT | np.integer[Any] | np.float16 | np.bool_],
68+
A: _ToLinearOperator[_AsFloat64],
69+
B: _SparseOrDense[_ToFloat64, tuple[int, int]],
70+
start: onp.ToFloat | None = None,
71+
stop: onp.ToFloat | None = None,
72+
num: op.CanIndex | None = None,
73+
endpoint: bool | None = None,
74+
traceA: onp.ToComplex | None = None,
75+
) -> onp.Array2D[np.float64]: ...
76+
@overload
77+
def expm_multiply(
78+
A: _ToLinearOperator[_InexactT],
79+
B: _SparseOrDense[_InexactT | npc.integer | np.bool_, tuple[int, int]],
3080
start: onp.ToFloat | None = None,
3181
stop: onp.ToFloat | None = None,
3282
num: op.CanIndex | None = None,
3383
endpoint: bool | None = None,
3484
traceA: onp.ToComplex | None = None,
35-
) -> onp.Array2D[_SCT]: ...
85+
) -> onp.Array2D[_InexactT]: ...
3686
@overload # 1-d or 2-d
3787
def expm_multiply(
38-
A: LinearOperator[_SCT],
39-
B: onp.ArrayND[_SCT | np.float16 | np.integer[Any] | np.bool_],
88+
A: _ToLinearOperator[_AsFloat64],
89+
B: _SparseOrDense[_ToFloat64, tuple[Any, ...]],
90+
start: onp.ToFloat | None = None,
91+
stop: onp.ToFloat | None = None,
92+
num: op.CanIndex | None = None,
93+
endpoint: bool | None = None,
94+
traceA: onp.ToComplex | None = None,
95+
) -> onp.ArrayND[np.float64]: ...
96+
@overload
97+
def expm_multiply(
98+
A: _ToLinearOperator[_InexactT],
99+
B: _SparseOrDense[_InexactT | npc.integer | np.bool_, tuple[Any, ...]],
100+
start: onp.ToFloat | None = None,
101+
stop: onp.ToFloat | None = None,
102+
num: op.CanIndex | None = None,
103+
endpoint: bool | None = None,
104+
traceA: onp.ToComplex | None = None,
105+
) -> onp.ArrayND[_InexactT]: ...
106+
@overload # fallback
107+
def expm_multiply(
108+
A: _ToLinearOperator[npc.number],
109+
B: _SparseOrDense[npc.number, tuple[Any, ...]],
40110
start: onp.ToFloat | None = None,
41111
stop: onp.ToFloat | None = None,
42112
num: op.CanIndex | None = None,
43113
endpoint: bool | None = None,
44114
traceA: onp.ToComplex | None = None,
45-
) -> onp.Array1D[_SCT] | onp.Array2D[_SCT]: ...
115+
) -> onp.ArrayND[Incomplete]: ...
File renamed without changes.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Any, assert_type
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
from scipy.sparse import coo_array
7+
from scipy.sparse.linalg import LinearOperator, expm_multiply
8+
9+
_dense_i8_1d: np.ndarray[tuple[int], np.dtype[np.int8]]
10+
_dense_i8_2d: np.ndarray[tuple[int, int], np.dtype[np.int8]]
11+
_dense_i8_nd: npt.NDArray[np.int8]
12+
13+
_dense_f32_1d: np.ndarray[tuple[int], np.dtype[np.float32]]
14+
_dense_f32_2d: np.ndarray[tuple[int, int], np.dtype[np.float32]]
15+
_dense_f32_nd: npt.NDArray[np.float32]
16+
17+
_sparse_i8_1d: coo_array[np.int8, tuple[int]]
18+
_sparse_i8_2d: coo_array[np.int8, tuple[int, int]]
19+
_sparse_i8_nd: coo_array[np.int8]
20+
21+
_sparse_f32_1d: coo_array[np.float32, tuple[int]]
22+
_sparse_f32_2d: coo_array[np.float32, tuple[int, int]]
23+
_sparse_f32_nd: coo_array[np.float32]
24+
25+
_linop_i8: LinearOperator[np.int8]
26+
_linop_f32: LinearOperator[np.float32]
27+
28+
#
29+
30+
assert_type(expm_multiply(_dense_i8_2d, _dense_f32_1d), np.ndarray[tuple[int], np.dtype[np.float64]])
31+
assert_type(expm_multiply(_dense_i8_2d, _dense_f32_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]])
32+
assert_type(expm_multiply(_dense_i8_2d, _dense_f32_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float64]])
33+
34+
assert_type(expm_multiply(_dense_i8_nd, _dense_f32_1d), np.ndarray[tuple[int], np.dtype[np.float64]])
35+
assert_type(expm_multiply(_dense_i8_nd, _dense_f32_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]])
36+
assert_type(expm_multiply(_dense_i8_nd, _dense_f32_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float64]])
37+
38+
assert_type(expm_multiply(_sparse_i8_2d, _dense_f32_1d), np.ndarray[tuple[int], np.dtype[np.float64]])
39+
assert_type(expm_multiply(_sparse_i8_2d, _dense_f32_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]])
40+
assert_type(expm_multiply(_sparse_i8_2d, _dense_f32_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float64]])
41+
42+
assert_type(expm_multiply(_linop_i8, _dense_f32_1d), np.ndarray[tuple[int], np.dtype[np.float64]])
43+
assert_type(expm_multiply(_linop_i8, _dense_f32_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]])
44+
assert_type(expm_multiply(_linop_i8, _dense_f32_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float64]])
45+
46+
#
47+
48+
assert_type(expm_multiply(_dense_i8_2d, _dense_i8_1d), np.ndarray[tuple[int], np.dtype[np.float64]])
49+
assert_type(expm_multiply(_dense_i8_2d, _dense_i8_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]])
50+
assert_type(expm_multiply(_dense_i8_2d, _dense_i8_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float64]])
51+
52+
assert_type(expm_multiply(_dense_i8_nd, _dense_i8_1d), np.ndarray[tuple[int], np.dtype[np.float64]])
53+
assert_type(expm_multiply(_dense_i8_nd, _dense_i8_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]])
54+
assert_type(expm_multiply(_dense_i8_nd, _dense_i8_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float64]])
55+
56+
assert_type(expm_multiply(_sparse_i8_2d, _dense_i8_1d), np.ndarray[tuple[int], np.dtype[np.float64]])
57+
assert_type(expm_multiply(_sparse_i8_2d, _dense_i8_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]])
58+
assert_type(expm_multiply(_sparse_i8_2d, _dense_i8_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float64]])
59+
60+
assert_type(expm_multiply(_linop_i8, _dense_i8_1d), np.ndarray[tuple[int], np.dtype[np.float64]])
61+
assert_type(expm_multiply(_linop_i8, _dense_i8_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]])
62+
assert_type(expm_multiply(_linop_i8, _dense_i8_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float64]])
63+
64+
#
65+
66+
assert_type(expm_multiply(_dense_f32_2d, _dense_f32_1d), np.ndarray[tuple[int], np.dtype[np.float32]])
67+
assert_type(expm_multiply(_dense_f32_2d, _dense_f32_2d), np.ndarray[tuple[int, int], np.dtype[np.float32]])
68+
assert_type(expm_multiply(_dense_f32_2d, _dense_f32_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float32]])
69+
70+
assert_type(expm_multiply(_dense_f32_nd, _dense_f32_1d), np.ndarray[tuple[int], np.dtype[np.float32]])
71+
assert_type(expm_multiply(_dense_f32_nd, _dense_f32_2d), np.ndarray[tuple[int, int], np.dtype[np.float32]])
72+
assert_type(expm_multiply(_dense_f32_nd, _dense_f32_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float32]])
73+
74+
assert_type(expm_multiply(_sparse_f32_2d, _dense_f32_1d), np.ndarray[tuple[int], np.dtype[np.float32]])
75+
assert_type(expm_multiply(_sparse_f32_2d, _dense_f32_2d), np.ndarray[tuple[int, int], np.dtype[np.float32]])
76+
assert_type(expm_multiply(_sparse_f32_2d, _dense_f32_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float32]])
77+
78+
assert_type(expm_multiply(_linop_f32, _dense_f32_1d), np.ndarray[tuple[int], np.dtype[np.float32]])
79+
assert_type(expm_multiply(_linop_f32, _dense_f32_2d), np.ndarray[tuple[int, int], np.dtype[np.float32]])
80+
assert_type(expm_multiply(_linop_f32, _dense_f32_nd), np.ndarray[tuple[Any, ...], np.dtype[np.float32]])

0 commit comments

Comments
 (0)