Skip to content

Commit d7c0e9a

Browse files
committed
linalg: type-tests for inv
1 parent 241509b commit d7c0e9a

File tree

1 file changed

+69
-2
lines changed

1 file changed

+69
-2
lines changed

tests/linalg/test__basic.pyi

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,39 +6,60 @@ import numpy as np
66
import optype.numpy as onp
77
import optype.numpy.compat as npc
88

9-
from scipy.linalg import solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular
9+
from scipy.linalg import inv, solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular
10+
11+
b1_nd: onp.ArrayND[np.bool_]
1012

1113
i8_1d: onp.Array1D[np.int8]
1214
i8_2d: onp.Array2D[np.int8]
1315
i8_3d: onp.Array3D[np.int8]
16+
i8_nd: onp.ArrayND[np.int8]
17+
18+
i32_1d: onp.Array1D[np.int32]
19+
i32_2d: onp.Array2D[np.int32]
20+
i32_3d: onp.Array3D[np.int32]
21+
i32_nd: onp.ArrayND[np.int32]
1422

1523
f16_1d: onp.Array1D[np.float16]
1624
f16_2d: onp.Array2D[np.float16]
1725
f16_3d: onp.Array3D[np.float16]
26+
f16_nd: onp.ArrayND[np.float16]
1827

1928
f32_1d: onp.Array1D[np.float32]
2029
f32_2d: onp.Array2D[np.float32]
2130
f32_3d: onp.Array3D[np.float32]
31+
f32_nd: onp.ArrayND[np.float32]
2232

2333
f64_1d: onp.Array1D[np.float64]
2434
f64_2d: onp.Array2D[np.float64]
2535
f64_3d: onp.Array3D[np.float64]
36+
f64_nd: onp.ArrayND[np.float64]
2637

2738
f80_1d: onp.Array1D[np.longdouble]
2839
f80_2d: onp.Array2D[np.longdouble]
2940
f80_3d: onp.Array3D[np.longdouble]
41+
f80_nd: onp.ArrayND[np.longdouble]
3042

3143
c64_1d: onp.Array1D[np.complex64]
3244
c64_2d: onp.Array2D[np.complex64]
3345
c64_3d: onp.Array3D[np.complex64]
46+
c64_nd: onp.ArrayND[np.complex64]
3447

3548
c128_1d: onp.Array1D[np.complex128]
3649
c128_2d: onp.Array2D[np.complex128]
3750
c128_3d: onp.Array3D[np.complex128]
51+
c128_nd: onp.ArrayND[np.complex128]
3852

3953
c160_1d: onp.Array1D[np.clongdouble]
4054
c160_2d: onp.Array2D[np.clongdouble]
4155
c160_3d: onp.Array3D[np.clongdouble]
56+
c160_nd: onp.ArrayND[np.clongdouble]
57+
58+
py_b_2d: list[list[bool]]
59+
py_b_3d: list[list[list[bool]]]
60+
61+
py_i_2d: list[list[int]]
62+
py_i_3d: list[list[list[int]]]
4263

4364
py_f_1d: list[float]
4465
py_f_2d: list[list[float]]
@@ -314,4 +335,50 @@ assert_type(solve_circulant(py_c_1d, py_c_3d), onp.ArrayND[np.complex128])
314335
assert_type(solve_circulant(py_c_2d, py_c_1d), onp.ArrayND[np.complex128])
315336

316337
###
317-
# TODO(jorenham): inv, pinv, pinvh, det, lstsq, matrix_balance, matmul_toeplitz
338+
# inv
339+
340+
assert_type(inv(f32_2d), onp.Array2D[np.float32])
341+
assert_type(inv(f64_2d), onp.Array2D[np.float64])
342+
assert_type(inv(c64_2d), onp.Array2D[np.complex64])
343+
assert_type(inv(c128_2d), onp.Array2D[np.complex128])
344+
345+
assert_type(inv(py_b_2d), onp.Array2D[np.float32])
346+
assert_type(inv(py_i_2d), onp.Array2D[np.float64])
347+
assert_type(inv(py_f_2d), onp.Array2D[np.float64])
348+
assert_type(inv(py_c_2d), onp.Array2D[np.complex128])
349+
350+
assert_type(inv(f32_3d), onp.Array3D[np.float32])
351+
assert_type(inv(f64_3d), onp.Array3D[np.float64])
352+
assert_type(inv(c64_3d), onp.Array3D[np.complex64])
353+
assert_type(inv(c128_3d), onp.Array3D[np.complex128])
354+
355+
assert_type(inv(py_b_3d), onp.ArrayND[np.float32])
356+
assert_type(inv(py_i_3d), onp.ArrayND[np.float64])
357+
assert_type(inv(py_f_3d), onp.ArrayND[np.float64])
358+
assert_type(inv(py_c_3d), onp.ArrayND[np.complex128])
359+
360+
assert_type(inv(b1_nd), onp.ArrayND[np.float32])
361+
assert_type(inv(i8_nd), onp.ArrayND[np.float32])
362+
assert_type(inv(f16_nd), onp.ArrayND[np.float32])
363+
assert_type(inv(f32_nd), onp.ArrayND[np.float32])
364+
assert_type(inv(i32_nd), onp.ArrayND[np.float64])
365+
assert_type(inv(f64_nd), onp.ArrayND[np.float64])
366+
assert_type(inv(f80_nd), onp.ArrayND[np.float64])
367+
assert_type(inv(c64_nd), onp.ArrayND[np.complex64])
368+
assert_type(inv(c128_nd), onp.ArrayND[np.complex128])
369+
assert_type(inv(c160_nd), onp.ArrayND[np.complex128])
370+
371+
###
372+
# TODO(jorenham): det
373+
374+
###
375+
# TODO(jorenham): lstsq
376+
377+
###
378+
# TODO(jorenham): pinv[h]
379+
380+
###
381+
# TODO(jorenham): matrix_balance
382+
383+
###
384+
# TODO(jorenham): matmul_toeplitz

0 commit comments

Comments
 (0)