Skip to content

Commit 9f32711

Browse files
authored
linalg: improved det annotations (#744)
2 parents 40829a6 + 4a80fa1 commit 9f32711

File tree

2 files changed

+51
-16
lines changed

2 files changed

+51
-16
lines changed

scipy-stubs/linalg/_basic.pyi

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,6 @@ def solve_circulant(
994994
) -> onp.ArrayND[npc.inexact]: ...
995995

996996
#
997-
998997
@overload # 2d bool sequence
999998
def inv(a: Sequence[Sequence[bool]], overwrite_a: bool = False, check_finite: bool = True) -> onp.Array2D[np.float32]: ...
1000999
@overload # Nd bool sequence
@@ -1036,19 +1035,35 @@ def inv(
10361035
a: onp.CanArrayND[np.complex128 | np.clongdouble, _ShapeT], overwrite_a: bool = False, check_finite: bool = True
10371036
) -> onp.ArrayND[np.complex128, _ShapeT]: ...
10381037

1039-
# TODO(jorenham): improve this
1040-
@overload # floating 2d
1041-
def det(a: onp.ToFloatStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Float: ...
1042-
@overload # floating 3d
1043-
def det(a: onp.ToFloatStrict3D, overwrite_a: bool = False, check_finite: bool = True) -> _Float1D: ...
1044-
@overload # floating
1045-
def det(a: onp.ToFloatND, overwrite_a: bool = False, check_finite: bool = True) -> _Float | _FloatND: ...
1046-
@overload # complexfloating 2d
1047-
def det(a: onp.ToJustComplexStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Inexact1D: ...
1048-
@overload # complexfloating 3d
1049-
def det(a: onp.ToJustComplexStrict3D, overwrite_a: bool = False, check_finite: bool = True) -> _InexactND: ...
1050-
@overload # complexfloating
1051-
def det(a: onp.ToComplexND, overwrite_a: bool = False, check_finite: bool = True) -> _Inexact | _InexactND: ...
1038+
# NOTE: The order of the overloads has been carefully chosen to avoid triggering a Pyright bug.
1039+
@overload # +float64 2d
1040+
def det(a: onp.ToFloat64Strict2D, overwrite_a: bool = False, check_finite: bool = True) -> np.float64: ...
1041+
@overload # complex128 | complex64 2d
1042+
def det(
1043+
a: onp.ToArrayStrict2D[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True
1044+
) -> np.complex128: ...
1045+
@overload # +float64 3d
1046+
def det(a: onp.ToFloat64Strict3D, overwrite_a: bool = False, check_finite: bool = True) -> onp.Array1D[np.float64]: ...
1047+
@overload # complex128 | complex64 3d
1048+
def det(
1049+
a: onp.ToArrayStrict3D[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True
1050+
) -> onp.Array1D[np.complex128]: ...
1051+
@overload # +float64 ND
1052+
def det(a: onp.ToFloat64_ND, overwrite_a: bool = False, check_finite: bool = True) -> np.float64 | onp.ArrayND[np.float64]: ...
1053+
@overload # complex128 | complex64 Nd
1054+
def det(
1055+
a: onp.ToArrayND[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True
1056+
) -> np.complex128 | onp.ArrayND[np.complex128]: ...
1057+
@overload # +complex128 2d
1058+
def det(a: onp.ToComplex128Strict2D, overwrite_a: bool = False, check_finite: bool = True) -> np.float64 | np.complex128: ...
1059+
@overload # +complex128 3d
1060+
def det(
1061+
a: onp.ToComplex128Strict3D, overwrite_a: bool = False, check_finite: bool = True
1062+
) -> onp.Array1D[np.float64 | np.complex128]: ...
1063+
@overload # +complex128 Nd
1064+
def det(
1065+
a: onp.ToComplex128_ND, overwrite_a: bool = False, check_finite: bool = True
1066+
) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
10521067

10531068
# TODO(jorenham): improve this
10541069
@overload # (float[:, :], float[:]) -> (float[:], float[], ...)

tests/linalg/test__basic.pyi

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import numpy as np
66
import optype.numpy as onp
77
import optype.numpy.compat as npc
88

9-
from scipy.linalg import inv, solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular
9+
from scipy.linalg import det, inv, solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular
1010

1111
b1_nd: onp.ArrayND[np.bool_]
1212

@@ -369,7 +369,27 @@ assert_type(inv(c128_nd), onp.ArrayND[np.complex128])
369369
assert_type(inv(c160_nd), onp.ArrayND[np.complex128])
370370

371371
###
372-
# TODO(jorenham): det
372+
# det
373+
374+
assert_type(det(f32_2d), np.float64)
375+
assert_type(det(f64_2d), np.float64)
376+
assert_type(det(c64_2d), np.complex128)
377+
assert_type(det(c128_2d), np.complex128)
378+
379+
assert_type(det(py_b_2d), np.float64)
380+
assert_type(det(py_i_2d), np.float64)
381+
assert_type(det(py_f_2d), np.float64)
382+
assert_type(det(py_c_2d), np.complex128)
383+
384+
assert_type(det(f32_3d), onp.Array1D[np.float64])
385+
assert_type(det(f64_3d), onp.Array1D[np.float64])
386+
assert_type(det(c64_3d), onp.Array1D[np.complex128])
387+
assert_type(det(c128_3d), onp.Array1D[np.complex128])
388+
389+
assert_type(det(py_b_3d), onp.Array1D[np.float64])
390+
assert_type(det(py_i_3d), onp.Array1D[np.float64])
391+
assert_type(det(py_f_3d), onp.Array1D[np.float64])
392+
assert_type(det(py_c_3d), onp.Array1D[np.complex128])
373393

374394
###
375395
# TODO(jorenham): lstsq

0 commit comments

Comments
 (0)