Skip to content

Commit 5595ba6

Browse files
committed
sparse: type-test csc array/matrix constructors
1 parent 5a15c4d commit 5595ba6

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

tests/sparse/test_csc.pyi

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# ruff: noqa: ERA001
2+
from typing import assert_type
3+
4+
import numpy as np
5+
6+
from ._types import ScalarType, csr_arr, csr_mat
7+
from scipy.sparse import csc_array, csc_matrix
8+
9+
scalartype: ScalarType
10+
11+
shape2: tuple[int, int]
12+
13+
ind1: np.ndarray[tuple[int], np.dtype[np.intp]]
14+
data1: np.ndarray[tuple[int], np.dtype[ScalarType]]
15+
data2: np.ndarray[tuple[int, int], np.dtype[ScalarType]]
16+
17+
csc_spec2: tuple[
18+
np.ndarray[tuple[int], np.dtype[ScalarType]],
19+
tuple[np.ndarray[tuple[int], np.dtype[np.intp]], np.ndarray[tuple[int], np.dtype[np.intp]]],
20+
]
21+
csc_spec3: tuple[
22+
np.ndarray[tuple[int], np.dtype[ScalarType]],
23+
np.ndarray[tuple[int], np.dtype[np.intp]],
24+
np.ndarray[tuple[int], np.dtype[np.intp]],
25+
]
26+
27+
###
28+
# CSC matrix constructor
29+
30+
# csc_matrix(D)
31+
assert_type(csc_matrix(data2), csc_matrix[ScalarType])
32+
assert_type(csc_matrix(data2, dtype=scalartype), csc_matrix[ScalarType])
33+
assert_type(csc_matrix(data2, dtype=bool), csc_matrix[np.bool_])
34+
assert_type(csc_matrix(data2, dtype=int), csc_matrix[np.int_])
35+
assert_type(csc_matrix(data2, dtype=float), csc_matrix[np.float64])
36+
assert_type(csc_matrix(data2, dtype=complex), csc_matrix[np.complex128])
37+
38+
# csc_matrix(S)
39+
assert_type(csc_matrix(csr_arr), csc_matrix[ScalarType])
40+
assert_type(csc_matrix(csr_mat), csc_matrix[ScalarType])
41+
42+
# csc_matrix((M, N), [dtype])
43+
assert_type(csc_matrix(shape2), csc_matrix[np.float64])
44+
assert_type(csc_matrix(shape2, dtype=scalartype), csc_matrix[ScalarType])
45+
assert_type(csc_matrix(shape2, dtype=bool), csc_matrix[np.bool_])
46+
assert_type(csc_matrix(shape2, dtype=int), csc_matrix[np.int_])
47+
assert_type(csc_matrix(shape2, dtype=float), csc_matrix[np.float64])
48+
assert_type(csc_matrix(shape2, dtype=complex), csc_matrix[np.complex128])
49+
50+
# csc_matrix((data, (row_ind, col_ind)), [shape=(M, N)])
51+
assert_type(csc_matrix(csc_spec2), csc_matrix[ScalarType])
52+
assert_type(csc_matrix(csc_spec2, shape2), csc_matrix[ScalarType])
53+
assert_type(csc_matrix(csc_spec2, shape=shape2), csc_matrix[ScalarType])
54+
55+
assert_type(csc_matrix(csc_spec2, dtype=scalartype), csc_matrix[ScalarType])
56+
assert_type(csc_matrix(csc_spec2, dtype=bool), csc_matrix[np.bool_])
57+
assert_type(csc_matrix(csc_spec2, dtype=int), csc_matrix[np.int_])
58+
assert_type(csc_matrix(csc_spec2, dtype=float), csc_matrix[np.float64])
59+
assert_type(csc_matrix(csc_spec2, dtype=complex), csc_matrix[np.complex128])
60+
61+
# csc_matrix((data, indices, indptr), [shape=(M, N)])
62+
# NOTE: mypy incorrectly infers `csc_array[Any]` here, but it is correct in pyright.
63+
assert_type(csc_matrix(csc_spec3), csc_matrix[ScalarType]) # type: ignore[assert-type]
64+
assert_type(csc_matrix(csc_spec3, shape2), csc_matrix[ScalarType]) # type: ignore[assert-type]
65+
assert_type(csc_matrix(csc_spec3, shape=shape2), csc_matrix[ScalarType]) # type: ignore[assert-type]
66+
67+
assert_type(csc_matrix(csc_spec3, dtype=scalartype), csc_matrix[ScalarType])
68+
assert_type(csc_matrix(csc_spec3, dtype=bool), csc_matrix[np.bool_])
69+
assert_type(csc_matrix(csc_spec3, dtype=int), csc_matrix[np.int_])
70+
assert_type(csc_matrix(csc_spec3, dtype=float), csc_matrix[np.float64])
71+
assert_type(csc_matrix(csc_spec3, dtype=complex), csc_matrix[np.complex128])
72+
73+
###
74+
# CSC array constructor
75+
76+
# csc_array(D)
77+
assert_type(csc_array(data2), csc_array[ScalarType])
78+
assert_type(csc_array(data2, dtype=scalartype), csc_array[ScalarType])
79+
assert_type(csc_array(data2, dtype=bool), csc_array[np.bool_])
80+
assert_type(csc_array(data2, dtype=int), csc_array[np.int_])
81+
assert_type(csc_array(data2, dtype=float), csc_array[np.float64])
82+
assert_type(csc_array(data2, dtype=complex), csc_array[np.complex128])
83+
84+
# csc_matrix(S)
85+
assert_type(csc_array(csr_arr), csc_array[ScalarType])
86+
assert_type(csc_array(csr_mat), csc_array[ScalarType])
87+
88+
# csc_array((M, N), [dtype])
89+
assert_type(csc_array(shape2), csc_array[np.float64])
90+
assert_type(csc_array(shape2, dtype=scalartype), csc_array[ScalarType])
91+
assert_type(csc_array(shape2, dtype=bool), csc_array[np.bool_])
92+
assert_type(csc_array(shape2, dtype=int), csc_array[np.int_])
93+
assert_type(csc_array(shape2, dtype=float), csc_array[np.float64])
94+
assert_type(csc_array(shape2, dtype=complex), csc_array[np.complex128])
95+
96+
# csc_array((data, (row_ind, col_ind)), [shape=(M, N)])
97+
assert_type(csc_array(csc_spec2), csc_array[ScalarType])
98+
assert_type(csc_array(csc_spec2, shape2), csc_array[ScalarType])
99+
assert_type(csc_array(csc_spec2, shape=shape2), csc_array[ScalarType])
100+
101+
assert_type(csc_array(csc_spec2, dtype=scalartype), csc_array[ScalarType])
102+
assert_type(csc_array(csc_spec2, dtype=bool), csc_array[np.bool_])
103+
assert_type(csc_array(csc_spec2, dtype=int), csc_array[np.int_])
104+
assert_type(csc_array(csc_spec2, dtype=float), csc_array[np.float64])
105+
assert_type(csc_array(csc_spec2, dtype=complex), csc_array[np.complex128])
106+
107+
# csc_array((data, indices, indptr), [shape=(M, N)])
108+
# NOTE: mypy incorrectly infers `csc_array[Any]` here, but it is correct in pyright.
109+
assert_type(csc_array(csc_spec3), csc_array[ScalarType]) # type: ignore[assert-type]
110+
assert_type(csc_array(csc_spec3, shape2), csc_array[ScalarType]) # type: ignore[assert-type]
111+
assert_type(csc_array(csc_spec3, shape=shape2), csc_array[ScalarType]) # type: ignore[assert-type]
112+
113+
assert_type(csc_array(csc_spec3, dtype=scalartype), csc_array[ScalarType])
114+
assert_type(csc_array(csc_spec3, dtype=bool), csc_array[np.bool_])
115+
assert_type(csc_array(csc_spec3, dtype=int), csc_array[np.int_])
116+
assert_type(csc_array(csc_spec3, dtype=float), csc_array[np.float64])
117+
assert_type(csc_array(csc_spec3, dtype=complex), csc_array[np.complex128])

0 commit comments

Comments
 (0)