Skip to content

Commit 1c968aa

Browse files
authored
linalg: support sparse arrays in clarkson_woodruff_transform (#711)
2 parents 2fa8bd2 + b14cf5a commit 1c968aa

File tree

4 files changed

+122
-15
lines changed

4 files changed

+122
-15
lines changed

scipy-stubs/linalg/_sketches.pyi

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,114 @@
1-
from typing import overload
1+
from typing import Any, Never, TypeAlias, TypeVar, overload
22

33
import numpy as np
4+
import optype as op
45
import optype.numpy as onp
6+
import optype.numpy.compat as npc
57

6-
from scipy.sparse import csc_matrix
8+
from scipy.sparse import csc_matrix, sparray, spmatrix
9+
from scipy.sparse._base import _spbase
710

811
__all__ = ["clarkson_woodruff_transform"]
912

13+
_ScalarT = TypeVar("_ScalarT", bound=npc.number | np.bool_)
14+
15+
_ToInt: TypeAlias = int | npc.integer
16+
_ToSparse: TypeAlias = _spbase[_ScalarT] | sparray[_ScalarT] | spmatrix[_ScalarT]
17+
1018
###
1119

12-
def cwt_matrix(n_rows: onp.ToInt, n_columns: onp.ToInt, rng: onp.random.ToRNG | None = None) -> csc_matrix[np.int_]: ...
20+
def cwt_matrix(n_rows: _ToInt, n_columns: _ToInt, rng: onp.random.ToRNG | None = None) -> csc_matrix[np.int_]: ...
1321

1422
#
1523
@overload
1624
def clarkson_woodruff_transform(
17-
input_matrix: onp.ToIntND,
18-
sketch_size: onp.ToInt,
25+
input_matrix: _ToSparse[Never],
26+
sketch_size: _ToInt,
27+
rng: onp.random.ToRNG | None = None,
28+
*,
29+
seed: onp.random.ToRNG | None = None,
30+
) -> csc_matrix[Any]: ...
31+
@overload
32+
def clarkson_woodruff_transform(
33+
input_matrix: _ToSparse[npc.integer | np.bool_],
34+
sketch_size: _ToInt,
35+
rng: onp.random.ToRNG | None = None,
36+
*,
37+
seed: onp.random.ToRNG | None = None,
38+
) -> csc_matrix[np.int_]: ...
39+
@overload
40+
def clarkson_woodruff_transform(
41+
input_matrix: _ToSparse[np.float32 | np.float64],
42+
sketch_size: _ToInt,
43+
rng: onp.random.ToRNG | None = None,
44+
*,
45+
seed: onp.random.ToRNG | None = None,
46+
) -> csc_matrix[np.float64]: ...
47+
@overload
48+
def clarkson_woodruff_transform(
49+
input_matrix: _ToSparse[np.longdouble],
50+
sketch_size: _ToInt,
1951
rng: onp.random.ToRNG | None = None,
2052
*,
2153
seed: onp.random.ToRNG | None = None,
54+
) -> csc_matrix[np.longdouble]: ...
55+
@overload
56+
def clarkson_woodruff_transform(
57+
input_matrix: _ToSparse[np.complex64 | np.complex128],
58+
sketch_size: _ToInt,
59+
rng: onp.random.ToRNG | None = None,
60+
*,
61+
seed: onp.random.ToRNG | None = None,
62+
) -> csc_matrix[np.complex128]: ...
63+
@overload
64+
def clarkson_woodruff_transform(
65+
input_matrix: _ToSparse[np.clongdouble],
66+
sketch_size: _ToInt,
67+
rng: onp.random.ToRNG | None = None,
68+
*,
69+
seed: onp.random.ToRNG | None = None,
70+
) -> csc_matrix[np.clongdouble]: ...
71+
@overload
72+
def clarkson_woodruff_transform(
73+
input_matrix: onp.CanArrayND[Never],
74+
sketch_size: _ToInt,
75+
rng: onp.random.ToRNG | None = None,
76+
*,
77+
seed: onp.random.ToRNG | None = None,
78+
) -> onp.ArrayND[Any]: ...
79+
@overload
80+
def clarkson_woodruff_transform(
81+
input_matrix: onp.ToIntND, sketch_size: _ToInt, rng: onp.random.ToRNG | None = None, *, seed: onp.random.ToRNG | None = None
2282
) -> onp.ArrayND[np.int_]: ...
2383
@overload
2484
def clarkson_woodruff_transform(
25-
input_matrix: onp.ToJustFloat64_ND,
26-
sketch_size: onp.ToInt,
85+
input_matrix: onp.ToArrayND[op.JustFloat, np.float16 | np.float32 | np.float64],
86+
sketch_size: _ToInt,
2787
rng: onp.random.ToRNG | None = None,
2888
*,
2989
seed: onp.random.ToRNG | None = None,
3090
) -> onp.ArrayND[np.float64]: ...
3191
@overload
3292
def clarkson_woodruff_transform(
33-
input_matrix: onp.ToJustFloatND,
34-
sketch_size: onp.ToInt,
93+
input_matrix: onp.ToArrayND[np.longdouble, np.longdouble],
94+
sketch_size: _ToInt,
3595
rng: onp.random.ToRNG | None = None,
3696
*,
3797
seed: onp.random.ToRNG | None = None,
38-
) -> onp.ArrayND[np.float64 | np.longdouble]: ...
98+
) -> onp.ArrayND[np.longdouble]: ...
3999
@overload
40100
def clarkson_woodruff_transform(
41-
input_matrix: onp.ToJustComplex128_ND,
42-
sketch_size: onp.ToInt,
101+
input_matrix: onp.ToArrayND[op.JustComplex, np.complex64 | np.complex128],
102+
sketch_size: _ToInt,
43103
rng: onp.random.ToRNG | None = None,
44104
*,
45105
seed: onp.random.ToRNG | None = None,
46106
) -> onp.ArrayND[np.complex128]: ...
47107
@overload
48108
def clarkson_woodruff_transform(
49-
input_matrix: onp.ToJustComplexND,
50-
sketch_size: onp.ToInt,
109+
input_matrix: onp.ToArrayND[np.clongdouble, np.clongdouble],
110+
sketch_size: _ToInt,
51111
rng: onp.random.ToRNG | None = None,
52112
*,
53113
seed: onp.random.ToRNG | None = None,
54-
) -> onp.ArrayND[np.complex128 | np.clongdouble]: ...
114+
) -> onp.ArrayND[np.clongdouble]: ...
File renamed without changes.

tests/linalg/test__sketches.pyi

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Any, assert_type
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
from scipy.linalg import clarkson_woodruff_transform
7+
from scipy.sparse import csc_matrix, sparray, spmatrix
8+
9+
like_bool_2d: list[list[bool]]
10+
like_i64_2d: list[list[int]]
11+
like_f64_2d: list[list[float]]
12+
like_c128_2d: list[list[complex]]
13+
14+
arr_any: npt.NDArray[Any]
15+
arr_i8: npt.NDArray[np.int8]
16+
arr_f32: npt.NDArray[np.float32]
17+
arr_f80: npt.NDArray[np.longdouble]
18+
arr_c64: npt.NDArray[np.complex64]
19+
arr_c160: npt.NDArray[np.clongdouble]
20+
21+
sparse_any: spmatrix[Any] | sparray[Any]
22+
sparse_i8: spmatrix[np.int8] | sparray[np.int8]
23+
sparse_f32: spmatrix[np.float32] | sparray[np.float32]
24+
sparse_f80: spmatrix[np.longdouble] | sparray[np.longdouble]
25+
sparse_c64: spmatrix[np.complex64] | sparray[np.complex64]
26+
sparse_c160: spmatrix[np.clongdouble] | sparray[np.clongdouble]
27+
28+
###
29+
30+
assert_type(clarkson_woodruff_transform(like_bool_2d, 2), npt.NDArray[np.int_])
31+
assert_type(clarkson_woodruff_transform(like_i64_2d, 2), npt.NDArray[np.int_])
32+
assert_type(clarkson_woodruff_transform(like_f64_2d, 2), npt.NDArray[np.float64])
33+
assert_type(clarkson_woodruff_transform(like_c128_2d, 2), npt.NDArray[np.complex128])
34+
35+
assert_type(clarkson_woodruff_transform(arr_any, 2), npt.NDArray[Any]) # type: ignore[assert-type]
36+
assert_type(clarkson_woodruff_transform(arr_i8, 2), npt.NDArray[np.int_])
37+
assert_type(clarkson_woodruff_transform(arr_f32, 2), npt.NDArray[np.float64])
38+
assert_type(clarkson_woodruff_transform(arr_c64, 2), npt.NDArray[np.complex128])
39+
assert_type(clarkson_woodruff_transform(arr_f80, 2), npt.NDArray[np.longdouble])
40+
assert_type(clarkson_woodruff_transform(arr_c160, 2), npt.NDArray[np.clongdouble])
41+
42+
assert_type(clarkson_woodruff_transform(sparse_any, 2), csc_matrix[Any])
43+
assert_type(clarkson_woodruff_transform(sparse_i8, 2), csc_matrix[np.int_])
44+
assert_type(clarkson_woodruff_transform(sparse_f32, 2), csc_matrix[np.float64])
45+
assert_type(clarkson_woodruff_transform(sparse_c64, 2), csc_matrix[np.complex128])
46+
assert_type(clarkson_woodruff_transform(sparse_f80, 2), csc_matrix[np.longdouble])
47+
assert_type(clarkson_woodruff_transform(sparse_c160, 2), csc_matrix[np.clongdouble])

0 commit comments

Comments
 (0)