Skip to content

Commit ccf6ae5

Browse files
authored
linalg: improved inv annotations (#743)
2 parents 373827a + d7c0e9a commit ccf6ae5

File tree

2 files changed

+114
-13
lines changed

2 files changed

+114
-13
lines changed

scipy-stubs/linalg/_basic.pyi

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mypy: disable-error-code=overload-overlap
22

3+
from collections.abc import Sequence
34
from typing import Final, Literal, TypeAlias, TypeVar, overload
45

56
import numpy as np
@@ -23,20 +24,20 @@ __all__ = [
2324
"solveh_banded",
2425
]
2526

27+
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
2628
_T = TypeVar("_T")
29+
2730
_Tuple2: TypeAlias = tuple[_T, _T]
2831
_COrCR: TypeAlias = _T | _Tuple2[_T]
2932

3033
_Float: TypeAlias = npc.floating
3134
_Float0D: TypeAlias = onp.Array0D[_Float]
3235
_Float1D: TypeAlias = onp.Array1D[_Float]
33-
_Float2D: TypeAlias = onp.Array2D[_Float]
3436
_FloatND: TypeAlias = onp.ArrayND[_Float]
3537

3638
_Inexact: TypeAlias = npc.inexact
3739
_Inexact0D: TypeAlias = onp.Array0D[_Inexact]
3840
_Inexact1D: TypeAlias = onp.Array1D[_Inexact]
39-
_Inexact2D: TypeAlias = onp.Array2D[_Inexact]
4041
_InexactND: TypeAlias = onp.ArrayND[_Inexact]
4142

4243
_InputFloat: TypeAlias = onp.ToArrayND[float, np.float64 | np.longdouble | npc.integer | np.bool_]
@@ -992,15 +993,48 @@ def solve_circulant(
992993
outaxis: int = 0,
993994
) -> onp.ArrayND[npc.inexact]: ...
994995

995-
# TODO(jorenham): improve this
996-
@overload # floating 2d
997-
def inv(a: onp.ToFloatStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Float2D: ...
998-
@overload # floating
999-
def inv(a: onp.ToFloatND, overwrite_a: bool = False, check_finite: bool = True) -> _FloatND: ...
1000-
@overload # complexfloating 2d
1001-
def inv(a: onp.ToComplexStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Inexact2D: ...
1002-
@overload # complexfloating
1003-
def inv(a: onp.ToComplexND, overwrite_a: bool = False, check_finite: bool = True) -> _InexactND: ...
996+
#
997+
998+
@overload # 2d bool sequence
999+
def inv(a: Sequence[Sequence[bool]], overwrite_a: bool = False, check_finite: bool = True) -> onp.Array2D[np.float32]: ...
1000+
@overload # Nd bool sequence
1001+
def inv(a: Sequence[onp.SequenceND[bool]], overwrite_a: bool = False, check_finite: bool = True) -> onp.ArrayND[np.float32]: ...
1002+
@overload # 2d float or int sequence
1003+
def inv(
1004+
a: Sequence[Sequence[op.JustFloat | op.JustInt]], overwrite_a: bool = False, check_finite: bool = True
1005+
) -> onp.Array2D[np.float64]: ...
1006+
@overload # Nd float or int sequence
1007+
def inv(
1008+
a: Sequence[onp.SequenceND[op.JustFloat | op.JustInt]], overwrite_a: bool = False, check_finite: bool = True
1009+
) -> onp.ArrayND[np.float64]: ...
1010+
@overload # 2d complex sequence
1011+
def inv(
1012+
a: Sequence[Sequence[op.JustComplex]], overwrite_a: bool = False, check_finite: bool = True
1013+
) -> onp.Array2D[np.complex128]: ...
1014+
@overload # Nd complex sequence
1015+
def inv(
1016+
a: Sequence[onp.SequenceND[op.JustComplex]], overwrite_a: bool = False, check_finite: bool = True
1017+
) -> onp.ArrayND[np.complex128]: ...
1018+
@overload # generic shape, as float32
1019+
def inv(
1020+
a: onp.CanArrayND[np.float32 | npc.number16 | npc.integer8 | np.bool_, _ShapeT],
1021+
overwrite_a: bool = False,
1022+
check_finite: bool = True,
1023+
) -> onp.ArrayND[np.float32, _ShapeT]: ...
1024+
@overload # generic shape, as float64
1025+
def inv(
1026+
a: onp.CanArrayND[np.float64 | np.longdouble | npc.integer64 | npc.integer32, _ShapeT],
1027+
overwrite_a: bool = False,
1028+
check_finite: bool = True,
1029+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
1030+
@overload # generic shape, as complex64
1031+
def inv(
1032+
a: onp.CanArrayND[np.complex64, _ShapeT], overwrite_a: bool = False, check_finite: bool = True
1033+
) -> onp.ArrayND[np.complex64, _ShapeT]: ...
1034+
@overload # generic shape, as complex128
1035+
def inv(
1036+
a: onp.CanArrayND[np.complex128 | np.clongdouble, _ShapeT], overwrite_a: bool = False, check_finite: bool = True
1037+
) -> onp.ArrayND[np.complex128, _ShapeT]: ...
10041038

10051039
# TODO(jorenham): improve this
10061040
@overload # floating 2d

tests/linalg/test__basic.pyi

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,39 +6,60 @@ import numpy as np
66
import optype.numpy as onp
77
import optype.numpy.compat as npc
88

9-
from scipy.linalg import solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular
9+
from scipy.linalg import inv, solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular
10+
11+
b1_nd: onp.ArrayND[np.bool_]
1012

1113
i8_1d: onp.Array1D[np.int8]
1214
i8_2d: onp.Array2D[np.int8]
1315
i8_3d: onp.Array3D[np.int8]
16+
i8_nd: onp.ArrayND[np.int8]
17+
18+
i32_1d: onp.Array1D[np.int32]
19+
i32_2d: onp.Array2D[np.int32]
20+
i32_3d: onp.Array3D[np.int32]
21+
i32_nd: onp.ArrayND[np.int32]
1422

1523
f16_1d: onp.Array1D[np.float16]
1624
f16_2d: onp.Array2D[np.float16]
1725
f16_3d: onp.Array3D[np.float16]
26+
f16_nd: onp.ArrayND[np.float16]
1827

1928
f32_1d: onp.Array1D[np.float32]
2029
f32_2d: onp.Array2D[np.float32]
2130
f32_3d: onp.Array3D[np.float32]
31+
f32_nd: onp.ArrayND[np.float32]
2232

2333
f64_1d: onp.Array1D[np.float64]
2434
f64_2d: onp.Array2D[np.float64]
2535
f64_3d: onp.Array3D[np.float64]
36+
f64_nd: onp.ArrayND[np.float64]
2637

2738
f80_1d: onp.Array1D[np.longdouble]
2839
f80_2d: onp.Array2D[np.longdouble]
2940
f80_3d: onp.Array3D[np.longdouble]
41+
f80_nd: onp.ArrayND[np.longdouble]
3042

3143
c64_1d: onp.Array1D[np.complex64]
3244
c64_2d: onp.Array2D[np.complex64]
3345
c64_3d: onp.Array3D[np.complex64]
46+
c64_nd: onp.ArrayND[np.complex64]
3447

3548
c128_1d: onp.Array1D[np.complex128]
3649
c128_2d: onp.Array2D[np.complex128]
3750
c128_3d: onp.Array3D[np.complex128]
51+
c128_nd: onp.ArrayND[np.complex128]
3852

3953
c160_1d: onp.Array1D[np.clongdouble]
4054
c160_2d: onp.Array2D[np.clongdouble]
4155
c160_3d: onp.Array3D[np.clongdouble]
56+
c160_nd: onp.ArrayND[np.clongdouble]
57+
58+
py_b_2d: list[list[bool]]
59+
py_b_3d: list[list[list[bool]]]
60+
61+
py_i_2d: list[list[int]]
62+
py_i_3d: list[list[list[int]]]
4263

4364
py_f_1d: list[float]
4465
py_f_2d: list[list[float]]
@@ -314,4 +335,50 @@ assert_type(solve_circulant(py_c_1d, py_c_3d), onp.ArrayND[np.complex128])
314335
assert_type(solve_circulant(py_c_2d, py_c_1d), onp.ArrayND[np.complex128])
315336

316337
###
317-
# TODO(jorenham): inv, pinv, pinvh, det, lstsq, matrix_balance, matmul_toeplitz
338+
# inv
339+
340+
assert_type(inv(f32_2d), onp.Array2D[np.float32])
341+
assert_type(inv(f64_2d), onp.Array2D[np.float64])
342+
assert_type(inv(c64_2d), onp.Array2D[np.complex64])
343+
assert_type(inv(c128_2d), onp.Array2D[np.complex128])
344+
345+
assert_type(inv(py_b_2d), onp.Array2D[np.float32])
346+
assert_type(inv(py_i_2d), onp.Array2D[np.float64])
347+
assert_type(inv(py_f_2d), onp.Array2D[np.float64])
348+
assert_type(inv(py_c_2d), onp.Array2D[np.complex128])
349+
350+
assert_type(inv(f32_3d), onp.Array3D[np.float32])
351+
assert_type(inv(f64_3d), onp.Array3D[np.float64])
352+
assert_type(inv(c64_3d), onp.Array3D[np.complex64])
353+
assert_type(inv(c128_3d), onp.Array3D[np.complex128])
354+
355+
assert_type(inv(py_b_3d), onp.ArrayND[np.float32])
356+
assert_type(inv(py_i_3d), onp.ArrayND[np.float64])
357+
assert_type(inv(py_f_3d), onp.ArrayND[np.float64])
358+
assert_type(inv(py_c_3d), onp.ArrayND[np.complex128])
359+
360+
assert_type(inv(b1_nd), onp.ArrayND[np.float32])
361+
assert_type(inv(i8_nd), onp.ArrayND[np.float32])
362+
assert_type(inv(f16_nd), onp.ArrayND[np.float32])
363+
assert_type(inv(f32_nd), onp.ArrayND[np.float32])
364+
assert_type(inv(i32_nd), onp.ArrayND[np.float64])
365+
assert_type(inv(f64_nd), onp.ArrayND[np.float64])
366+
assert_type(inv(f80_nd), onp.ArrayND[np.float64])
367+
assert_type(inv(c64_nd), onp.ArrayND[np.complex64])
368+
assert_type(inv(c128_nd), onp.ArrayND[np.complex128])
369+
assert_type(inv(c160_nd), onp.ArrayND[np.complex128])
370+
371+
###
372+
# TODO(jorenham): det
373+
374+
###
375+
# TODO(jorenham): lstsq
376+
377+
###
378+
# TODO(jorenham): pinv[h]
379+
380+
###
381+
# TODO(jorenham): matrix_balance
382+
383+
###
384+
# TODO(jorenham): matmul_toeplitz

0 commit comments

Comments
 (0)