diff --git a/scipy-stubs/linalg/_basic.pyi b/scipy-stubs/linalg/_basic.pyi index e9e96251..570e3d49 100644 --- a/scipy-stubs/linalg/_basic.pyi +++ b/scipy-stubs/linalg/_basic.pyi @@ -994,7 +994,6 @@ def solve_circulant( ) -> onp.ArrayND[npc.inexact]: ... # - @overload # 2d bool sequence def inv(a: Sequence[Sequence[bool]], overwrite_a: bool = False, check_finite: bool = True) -> onp.Array2D[np.float32]: ... @overload # Nd bool sequence @@ -1036,19 +1035,35 @@ def inv( a: onp.CanArrayND[np.complex128 | np.clongdouble, _ShapeT], overwrite_a: bool = False, check_finite: bool = True ) -> onp.ArrayND[np.complex128, _ShapeT]: ... -# TODO(jorenham): improve this -@overload # floating 2d -def det(a: onp.ToFloatStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Float: ... -@overload # floating 3d -def det(a: onp.ToFloatStrict3D, overwrite_a: bool = False, check_finite: bool = True) -> _Float1D: ... -@overload # floating -def det(a: onp.ToFloatND, overwrite_a: bool = False, check_finite: bool = True) -> _Float | _FloatND: ... -@overload # complexfloating 2d -def det(a: onp.ToJustComplexStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Inexact1D: ... -@overload # complexfloating 3d -def det(a: onp.ToJustComplexStrict3D, overwrite_a: bool = False, check_finite: bool = True) -> _InexactND: ... -@overload # complexfloating -def det(a: onp.ToComplexND, overwrite_a: bool = False, check_finite: bool = True) -> _Inexact | _InexactND: ... +# NOTE: The order of the overloads has been carefully chosen to avoid triggering a Pyright bug. +@overload # +float64 2d +def det(a: onp.ToFloat64Strict2D, overwrite_a: bool = False, check_finite: bool = True) -> np.float64: ... +@overload # complex128 | complex64 2d +def det( + a: onp.ToArrayStrict2D[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True +) -> np.complex128: ... +@overload # +float64 3d +def det(a: onp.ToFloat64Strict3D, overwrite_a: bool = False, check_finite: bool = True) -> onp.Array1D[np.float64]: ... +@overload # complex128 | complex64 3d +def det( + a: onp.ToArrayStrict3D[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True +) -> onp.Array1D[np.complex128]: ... +@overload # +float64 ND +def det(a: onp.ToFloat64_ND, overwrite_a: bool = False, check_finite: bool = True) -> np.float64 | onp.ArrayND[np.float64]: ... +@overload # complex128 | complex64 Nd +def det( + a: onp.ToArrayND[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True +) -> np.complex128 | onp.ArrayND[np.complex128]: ... +@overload # +complex128 2d +def det(a: onp.ToComplex128Strict2D, overwrite_a: bool = False, check_finite: bool = True) -> np.float64 | np.complex128: ... +@overload # +complex128 3d +def det( + a: onp.ToComplex128Strict3D, overwrite_a: bool = False, check_finite: bool = True +) -> onp.Array1D[np.float64 | np.complex128]: ... +@overload # +complex128 Nd +def det( + a: onp.ToComplex128_ND, overwrite_a: bool = False, check_finite: bool = True +) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ... # TODO(jorenham): improve this @overload # (float[:, :], float[:]) -> (float[:], float[], ...) diff --git a/tests/linalg/test__basic.pyi b/tests/linalg/test__basic.pyi index eb30c91a..55eba162 100644 --- a/tests/linalg/test__basic.pyi +++ b/tests/linalg/test__basic.pyi @@ -6,7 +6,7 @@ import numpy as np import optype.numpy as onp import optype.numpy.compat as npc -from scipy.linalg import inv, solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular +from scipy.linalg import det, inv, solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular b1_nd: onp.ArrayND[np.bool_] @@ -369,7 +369,27 @@ assert_type(inv(c128_nd), onp.ArrayND[np.complex128]) assert_type(inv(c160_nd), onp.ArrayND[np.complex128]) ### -# TODO(jorenham): det +# det + +assert_type(det(f32_2d), np.float64) +assert_type(det(f64_2d), np.float64) +assert_type(det(c64_2d), np.complex128) +assert_type(det(c128_2d), np.complex128) + +assert_type(det(py_b_2d), np.float64) +assert_type(det(py_i_2d), np.float64) +assert_type(det(py_f_2d), np.float64) +assert_type(det(py_c_2d), np.complex128) + +assert_type(det(f32_3d), onp.Array1D[np.float64]) +assert_type(det(f64_3d), onp.Array1D[np.float64]) +assert_type(det(c64_3d), onp.Array1D[np.complex128]) +assert_type(det(c128_3d), onp.Array1D[np.complex128]) + +assert_type(det(py_b_3d), onp.Array1D[np.float64]) +assert_type(det(py_i_3d), onp.Array1D[np.float64]) +assert_type(det(py_f_3d), onp.Array1D[np.float64]) +assert_type(det(py_c_3d), onp.Array1D[np.complex128]) ### # TODO(jorenham): lstsq