Skip to content

Commit 93949d5

Browse files
committed
linalg: improved basic solve_* annotations
1 parent f679166 commit 93949d5

File tree

1 file changed

+37
-33
lines changed

1 file changed

+37
-33
lines changed

scipy-stubs/linalg/_basic.pyi

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ _InputC64: TypeAlias = onp.CanArrayND[np.complex64] | onp.SequenceND[onp.CanArra
5757
_InputC64Strict1D: TypeAlias = onp.CanArray1D[np.complex64] | Sequence[onp.CanArray0D[np.complex64]]
5858
_InputC64Strict2D: TypeAlias = onp.CanArray2D[np.complex64] | Sequence[_InputC64Strict1D]
5959

60+
_InputComplex: TypeAlias = onp.ToArrayND[op.JustComplex, np.complex128 | np.clongdouble]
61+
_InputComplexStrict1D: TypeAlias = onp.ToArrayStrict1D[op.JustComplex, np.complex128 | np.clongdouble]
62+
_InputComplexStrict2D: TypeAlias = onp.ToArrayStrict2D[op.JustComplex, np.complex128 | np.clongdouble]
63+
6064
_CoC64: TypeAlias = np.complex64 | _AsF32 | npc.integer16 | npc.integer8 | np.bool_
6165
_CoInputC64: TypeAlias = onp.CanArrayND[_CoC64] | onp.SequenceND[onp.CanArray[Any, np.dtype[_CoC64]]]
6266
_CoInputC64Strict1D: TypeAlias = onp.CanArray1D[_CoC64] | Sequence[onp.CanArray0D[_CoC64]]
@@ -193,7 +197,7 @@ def solve(
193197
) -> onp.ArrayND[np.complex64]: ...
194198
@overload # 2d ~complex128, +complex128
195199
def solve(
196-
a: onp.ToJustComplex128Strict2D,
200+
a: _InputComplexStrict2D,
197201
b: onp.ToComplexStrict1D | onp.ToComplexStrict2D,
198202
lower: bool = False,
199203
overwrite_a: bool = False,
@@ -204,7 +208,7 @@ def solve(
204208
) -> onp.Array2D[np.complex128]: ...
205209
@overload # Nd ~complex128, +complex128
206210
def solve(
207-
a: onp.ToJustComplex128_ND,
211+
a: _InputComplex,
208212
b: onp.ToComplexND,
209213
lower: bool = False,
210214
overwrite_a: bool = False,
@@ -216,7 +220,7 @@ def solve(
216220
@overload # 2d +complex128, ~complex128
217221
def solve(
218222
a: onp.ToComplexStrict2D,
219-
b: onp.ToJustComplex128Strict1D | onp.ToJustComplex128Strict2D,
223+
b: _InputComplexStrict1D | _InputComplexStrict2D,
220224
lower: bool = False,
221225
overwrite_a: bool = False,
222226
overwrite_b: bool = False,
@@ -227,7 +231,7 @@ def solve(
227231
@overload # Nd +complex128, ~complex128
228232
def solve(
229233
a: onp.ToComplexND,
230-
b: onp.ToJustComplex128_ND,
234+
b: _InputComplex,
231235
lower: bool = False,
232236
overwrite_a: bool = False,
233237
overwrite_b: bool = False,
@@ -418,7 +422,7 @@ def solve_triangular(
418422
@overload # 1d ~complex64, +complex64
419423
def solve_triangular(
420424
a: _InputC64Strict2D,
421-
b: _InputC64Strict1D | _InputF32Strict1D,
425+
b: _CoInputC64Strict1D,
422426
trans: _TransSystem = 0,
423427
lower: bool = False,
424428
unit_diagonal: bool = False,
@@ -477,7 +481,7 @@ def solve_triangular(
477481
) -> onp.ArrayND[np.complex64]: ...
478482
@overload # 1d ~complex128, +complex128
479483
def solve_triangular(
480-
a: onp.ToJustComplex128Strict2D,
484+
a: _InputComplexStrict2D,
481485
b: onp.ToComplexStrict1D,
482486
trans: _TransSystem = 0,
483487
lower: bool = False,
@@ -487,7 +491,7 @@ def solve_triangular(
487491
) -> onp.Array1D[np.complex128]: ...
488492
@overload # 2d ~complex128, +complex128
489493
def solve_triangular(
490-
a: onp.ToJustComplex128Strict2D,
494+
a: _InputComplexStrict2D,
491495
b: onp.ToComplexStrict2D,
492496
trans: _TransSystem = 0,
493497
lower: bool = False,
@@ -497,7 +501,7 @@ def solve_triangular(
497501
) -> onp.Array2D[np.complex128]: ...
498502
@overload # Nd ~complex128, +complex128
499503
def solve_triangular(
500-
a: onp.ToJustComplex128_ND,
504+
a: _InputComplex,
501505
b: onp.ToComplexND,
502506
trans: _TransSystem = 0,
503507
lower: bool = False,
@@ -508,7 +512,7 @@ def solve_triangular(
508512
@overload # 1d +complex128, ~complex128
509513
def solve_triangular(
510514
a: onp.ToComplexStrict2D,
511-
b: onp.ToJustComplex128Strict1D,
515+
b: _InputComplexStrict1D,
512516
trans: _TransSystem = 0,
513517
lower: bool = False,
514518
unit_diagonal: bool = False,
@@ -518,7 +522,7 @@ def solve_triangular(
518522
@overload # 2d +complex128, ~complex128
519523
def solve_triangular(
520524
a: onp.ToComplexStrict2D,
521-
b: onp.ToJustComplex128Strict2D,
525+
b: _InputComplexStrict2D,
522526
trans: _TransSystem = 0,
523527
lower: bool = False,
524528
unit_diagonal: bool = False,
@@ -528,7 +532,7 @@ def solve_triangular(
528532
@overload # Nd +complex128, ~complex128
529533
def solve_triangular(
530534
a: onp.ToComplexND,
531-
b: onp.ToJustComplex128_ND,
535+
b: _InputComplex,
532536
trans: _TransSystem = 0,
533537
lower: bool = False,
534538
unit_diagonal: bool = False,
@@ -795,7 +799,7 @@ def solve_banded(
795799
@overload # 1d ~complex128, +complex128
796800
def solve_banded(
797801
l_and_u: tuple[int, int],
798-
ab: onp.ToJustComplex128Strict2D,
802+
ab: _InputComplexStrict2D,
799803
b: onp.ToComplexStrict1D,
800804
overwrite_ab: bool = False,
801805
overwrite_b: bool = False,
@@ -804,7 +808,7 @@ def solve_banded(
804808
@overload # 2d ~complex128, +complex128
805809
def solve_banded(
806810
l_and_u: tuple[int, int],
807-
ab: onp.ToJustComplex128Strict2D,
811+
ab: _InputComplexStrict2D,
808812
b: onp.ToComplexStrict2D,
809813
overwrite_ab: bool = False,
810814
overwrite_b: bool = False,
@@ -813,7 +817,7 @@ def solve_banded(
813817
@overload # Nd ~complex128, +complex128
814818
def solve_banded(
815819
l_and_u: tuple[int, int],
816-
ab: onp.ToJustComplex128_ND,
820+
ab: _InputComplex,
817821
b: onp.ToComplexND,
818822
overwrite_ab: bool = False,
819823
overwrite_b: bool = False,
@@ -823,7 +827,7 @@ def solve_banded(
823827
def solve_banded(
824828
l_and_u: tuple[int, int],
825829
ab: onp.ToComplexStrict2D,
826-
b: onp.ToJustComplex128Strict1D,
830+
b: _InputComplexStrict1D,
827831
overwrite_ab: bool = False,
828832
overwrite_b: bool = False,
829833
check_finite: bool = True,
@@ -832,7 +836,7 @@ def solve_banded(
832836
def solve_banded(
833837
l_and_u: tuple[int, int],
834838
ab: onp.ToComplexStrict2D,
835-
b: onp.ToJustComplex128Strict2D,
839+
b: _InputComplexStrict2D,
836840
overwrite_ab: bool = False,
837841
overwrite_b: bool = False,
838842
check_finite: bool = True,
@@ -841,7 +845,7 @@ def solve_banded(
841845
def solve_banded(
842846
l_and_u: tuple[int, int],
843847
ab: onp.ToComplexND,
844-
b: onp.ToJustComplex128_ND,
848+
b: _InputComplex,
845849
overwrite_ab: bool = False,
846850
overwrite_b: bool = False,
847851
check_finite: bool = True,
@@ -1093,7 +1097,7 @@ def solveh_banded(
10931097
) -> onp.ArrayND[np.complex64]: ...
10941098
@overload # 1d ~complex128, +complex128
10951099
def solveh_banded(
1096-
ab: onp.ToJustComplex128Strict2D,
1100+
ab: _InputComplexStrict2D,
10971101
b: onp.ToComplexStrict1D,
10981102
overwrite_ab: bool = False,
10991103
overwrite_b: bool = False,
@@ -1102,7 +1106,7 @@ def solveh_banded(
11021106
) -> onp.Array1D[np.complex128]: ...
11031107
@overload # 2d ~complex128, +complex128
11041108
def solveh_banded(
1105-
ab: onp.ToJustComplex128Strict2D,
1109+
ab: _InputComplexStrict2D,
11061110
b: onp.ToComplexStrict2D,
11071111
overwrite_ab: bool = False,
11081112
overwrite_b: bool = False,
@@ -1111,7 +1115,7 @@ def solveh_banded(
11111115
) -> onp.Array2D[np.complex128]: ...
11121116
@overload # Nd ~complex128, +complex128
11131117
def solveh_banded(
1114-
ab: onp.ToJustComplex128_ND,
1118+
ab: _InputComplex,
11151119
b: onp.ToComplexND,
11161120
overwrite_ab: bool = False,
11171121
overwrite_b: bool = False,
@@ -1121,7 +1125,7 @@ def solveh_banded(
11211125
@overload # 1d +complex128, ~complex128
11221126
def solveh_banded(
11231127
ab: onp.ToComplexStrict2D,
1124-
b: onp.ToJustComplex128Strict1D,
1128+
b: _InputComplexStrict1D,
11251129
overwrite_ab: bool = False,
11261130
overwrite_b: bool = False,
11271131
lower: bool = False,
@@ -1130,7 +1134,7 @@ def solveh_banded(
11301134
@overload # 2d +complex128, ~complex128
11311135
def solveh_banded(
11321136
ab: onp.ToComplexStrict2D,
1133-
b: onp.ToJustComplex128Strict2D,
1137+
b: _InputComplexStrict2D,
11341138
overwrite_ab: bool = False,
11351139
overwrite_b: bool = False,
11361140
lower: bool = False,
@@ -1139,7 +1143,7 @@ def solveh_banded(
11391143
@overload # Nd +complex128, ~complex128
11401144
def solveh_banded(
11411145
ab: onp.ToComplexND,
1142-
b: onp.ToJustComplex128_ND,
1146+
b: _InputComplex,
11431147
overwrite_ab: bool = False,
11441148
overwrite_b: bool = False,
11451149
lower: bool = False,
@@ -1271,7 +1275,7 @@ def solve_toeplitz(
12711275
) -> onp.Array1D[np.complex128]: ...
12721276
@overload # 2d ~complex, +complex
12731277
def solve_toeplitz(
1274-
c_or_cr: _COrCR[onp.ToJustComplexStrict2D], b: onp.ToComplexStrict2D, check_finite: bool = True
1278+
c_or_cr: _COrCR[onp.ToJustComplexStrict1D], b: onp.ToComplexStrict2D, check_finite: bool = True
12751279
) -> onp.Array2D[np.complex128]: ...
12761280
@overload # Nd ~complex, +complex
12771281
def solve_toeplitz(
@@ -1283,7 +1287,7 @@ def solve_toeplitz(
12831287
) -> onp.Array1D[np.complex128]: ...
12841288
@overload # 2d +complex, ~complex
12851289
def solve_toeplitz(
1286-
c_or_cr: _COrCR[onp.ToComplexStrict2D], b: onp.ToJustComplexStrict2D, check_finite: bool = True
1290+
c_or_cr: _COrCR[onp.ToComplexStrict1D], b: onp.ToJustComplexStrict2D, check_finite: bool = True
12871291
) -> onp.Array2D[np.complex128]: ...
12881292
@overload # Nd +complex, ~complex
12891293
def solve_toeplitz(
@@ -1295,7 +1299,7 @@ def solve_toeplitz(
12951299
) -> onp.Array1D[np.float64 | np.complex128]: ...
12961300
@overload # 2d +complex, +complex
12971301
def solve_toeplitz(
1298-
c_or_cr: _COrCR[onp.ToComplexStrict2D], b: onp.ToComplexStrict2D, check_finite: bool = True
1302+
c_or_cr: _COrCR[onp.ToComplexStrict1D], b: onp.ToComplexStrict2D, check_finite: bool = True
12991303
) -> onp.Array2D[np.float64 | np.complex128]: ...
13001304
@overload # Nd +complex, +complex
13011305
def solve_toeplitz(
@@ -1634,7 +1638,7 @@ def solve_circulant(
16341638
outaxis: op.CanIndex = 0,
16351639
) -> onp.ArrayND[npc.inexact]: ...
16361640

1637-
#
1641+
# TODO(jorenham): improve this
16381642
@overload # floating 2d
16391643
def inv(a: onp.ToFloatStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Float2D: ...
16401644
@overload # floating
@@ -1644,7 +1648,7 @@ def inv(a: onp.ToComplexStrict2D, overwrite_a: bool = False, check_finite: bool
16441648
@overload # complexfloating
16451649
def inv(a: onp.ToComplexND, overwrite_a: bool = False, check_finite: bool = True) -> _ComplexND: ...
16461650

1647-
#
1651+
# TODO(jorenham): improve this
16481652
@overload # floating 2d
16491653
def det(a: onp.ToFloatStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Float: ...
16501654
@overload # floating 3d
@@ -1658,7 +1662,7 @@ def det(a: onp.ToJustComplexStrict3D, overwrite_a: bool = False, check_finite: b
16581662
@overload # complexfloating
16591663
def det(a: onp.ToComplexND, overwrite_a: bool = False, check_finite: bool = True) -> _Complex | _ComplexND: ...
16601664

1661-
#
1665+
# TODO(jorenham): improve this
16621666
@overload # (float[:, :], float[:]) -> (float[:], float[], ...)
16631667
def lstsq(
16641668
a: onp.ToFloatStrict2D,
@@ -1700,7 +1704,7 @@ def lstsq(
17001704
lapack_driver: _LapackDriver | None = None,
17011705
) -> tuple[_ComplexND, _Complex0D | _ComplexND, int, _ComplexND | None]: ...
17021706

1703-
#
1707+
# TODO(jorenham): improve this
17041708
@overload
17051709
def pinv( # (float[:, :], return_rank=False) -> float[:, :]
17061710
a: onp.ToFloatND,
@@ -1738,7 +1742,7 @@ def pinv(
17381742
check_finite: bool = True,
17391743
) -> tuple[_ComplexND, int]: ...
17401744

1741-
#
1745+
# TODO(jorenham): improve this
17421746
@overload # (float[:, :], return_rank=False) -> float[:, :]
17431747
def pinvh(
17441748
a: onp.ToFloatND,
@@ -1796,7 +1800,7 @@ def pinvh(
17961800
check_finite: bool = True,
17971801
) -> tuple[_ComplexND, int]: ...
17981802

1799-
#
1803+
# TODO(jorenham): improve this
18001804
@overload # (float[:, :], separate=True) -> (float[:, :], float[:, :])
18011805
def matrix_balance(
18021806
A: onp.ToFloatND,
@@ -1830,7 +1834,7 @@ def matrix_balance(
18301834
A: onp.ToComplexND, permute: onp.ToBool = True, scale: onp.ToBool = True, *, separate: onp.ToTrue, overwrite_a: bool = False
18311835
) -> tuple[_ComplexND, _Tuple2[_ComplexND]]: ...
18321836

1833-
#
1837+
# TODO(jorenham): improve this
18341838
@overload # floating 1d, 1d
18351839
def matmul_toeplitz(
18361840
c_or_cr: onp.ToFloatStrict1D | _Tuple2[onp.ToFloatStrict1D],

0 commit comments

Comments
 (0)