Skip to content

Commit a1b36b2

Browse files
authored
linalg: improved eigh return type inference for float64 (#755)
2 parents 8fea2ac + 5539a91 commit a1b36b2

File tree

2 files changed

+44
-27
lines changed

2 files changed

+44
-27
lines changed

scipy-stubs/linalg/_decomp.pyi

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,22 @@ def eig(
384384
...
385385

386386
#
387-
@overload # float, eigvals_only: False = ...
387+
@overload # +float64, eigvals_only: False = ...
388+
def eigh( #
389+
a: onp.ToArrayND[float, np.float64 | np.longdouble | npc.integer64 | npc.integer32],
390+
b: onp.ToFloat64_ND | None = None,
391+
*,
392+
lower: op.CanBool = True,
393+
eigvals_only: onp.ToFalse = False,
394+
overwrite_a: op.CanBool = False,
395+
overwrite_b: op.CanBool = False,
396+
type: _EigHType = 1,
397+
check_finite: op.CanBool = True,
398+
subset_by_index: _EigHSubsetByIndex | None = None,
399+
subset_by_value: _EigHSubsetByValue | None = None,
400+
driver: _DriverEV | _DriverGV | None = None,
401+
) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.float64]]: ...
402+
@overload # +float, eigvals_only: False = ...
388403
def eigh(
389404
a: onp.ToFloatND,
390405
b: onp.ToFloatND | None = None,
@@ -399,22 +414,22 @@ def eigh(
399414
subset_by_value: _EigHSubsetByValue | None = None,
400415
driver: _DriverEV | _DriverGV | None = None,
401416
) -> tuple[_FloatND, _FloatND]: ...
402-
@overload # float, eigvals_only: True
417+
@overload # ~complex, eigvals_only: False = ...
403418
def eigh(
404-
a: onp.ToFloatND,
405-
b: onp.ToFloatND | None = None,
419+
a: onp.ToJustComplexND,
420+
b: onp.ToComplexND | None = None,
406421
*,
407422
lower: op.CanBool = True,
408-
eigvals_only: onp.ToTrue,
423+
eigvals_only: onp.ToFalse = False,
409424
overwrite_a: op.CanBool = False,
410425
overwrite_b: op.CanBool = False,
411426
type: _EigHType = 1,
412427
check_finite: op.CanBool = True,
413428
subset_by_index: _EigHSubsetByIndex | None = None,
414429
subset_by_value: _EigHSubsetByValue | None = None,
415-
driver: _DriverEV | _EigHSubsetByValue | None = None,
416-
) -> _FloatND: ...
417-
@overload # complex, eigvals_only: False = ...
430+
driver: _DriverEV | _DriverGV | None = None,
431+
) -> tuple[_FloatND, _ComplexND]: ...
432+
@overload # +complex, eigvals_only: False = ...
418433
def eigh(
419434
a: onp.ToComplexND,
420435
b: onp.ToComplexND | None = None,
@@ -429,10 +444,10 @@ def eigh(
429444
subset_by_value: _EigHSubsetByValue | None = None,
430445
driver: _DriverEV | _DriverGV | None = None,
431446
) -> tuple[_FloatND, _InexactND]: ...
432-
@overload # complex, eigvals_only: True
447+
@overload # +complex128, eigvals_only: True
433448
def eigh(
434-
a: onp.ToComplexND,
435-
b: onp.ToComplexND | None = None,
449+
a: onp.ToArrayND[float, npc.inexact80 | npc.number64 | npc.integer32],
450+
b: onp.ToComplex128_ND | None = None,
436451
*,
437452
lower: op.CanBool = True,
438453
eigvals_only: onp.ToTrue,
@@ -443,22 +458,22 @@ def eigh(
443458
subset_by_index: _EigHSubsetByIndex | None = None,
444459
subset_by_value: _EigHSubsetByValue | None = None,
445460
driver: _DriverEV | _EigHSubsetByValue | None = None,
446-
) -> _FloatND: ...
447-
@overload # complex, eigvals_only: CanBool (catch-all)
461+
) -> onp.ArrayND[np.float64]: ...
462+
@overload # +complex, eigvals_only: True
448463
def eigh(
449464
a: onp.ToComplexND,
450465
b: onp.ToComplexND | None = None,
451466
*,
452-
lower: op.CanBool,
453-
eigvals_only: op.CanBool,
467+
lower: op.CanBool = True,
468+
eigvals_only: onp.ToTrue,
454469
overwrite_a: op.CanBool = False,
455470
overwrite_b: op.CanBool = False,
456471
type: _EigHType = 1,
457472
check_finite: op.CanBool = True,
458473
subset_by_index: _EigHSubsetByIndex | None = None,
459474
subset_by_value: _EigHSubsetByValue | None = None,
460475
driver: _DriverEV | _EigHSubsetByValue | None = None,
461-
) -> _FloatND | tuple[_FloatND, _InexactND]: ...
476+
) -> _FloatND: ...
462477

463478
#
464479
@overload # float, eigvals_only: False = ..., select: _SelectA = ...

scipy-stubs/linalg/_matfuncs.pyi

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ _Complex128ND: TypeAlias = onp.ArrayND[np.complex128]
4141
_ComplexND: TypeAlias = onp.ArrayND[npc.complexfloating]
4242
_InexactND: TypeAlias = onp.ArrayND[npc.inexact]
4343

44+
_AsFloat64ND: TypeAlias = onp.ToArrayND[float, np.float64 | npc.integer | np.bool_]
45+
4446
###
4547

4648
eps: Final[np.float64] = ... # undocumented
@@ -77,7 +79,7 @@ def fractional_matrix_power(A: onp.ToComplexND, t: onp.ToFloat) -> onp.ArrayND[A
7779

7880
#
7981
@overload
80-
def sqrtm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
82+
def sqrtm(A: _AsFloat64ND) -> _Float64ND: ...
8183
@overload
8284
def sqrtm(A: onp.ToFloatND) -> _FloatND: ...
8385
@overload
@@ -106,7 +108,7 @@ def sqrtm(A: onp.ToComplexND, disp: onp.ToFalse, blocksize: int) -> tuple[_Inexa
106108

107109
# NOTE: return dtype depends on the sign of the values
108110
@overload # +integer | ~float64
109-
def logm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND | _Complex128ND: ...
111+
def logm(A: _AsFloat64ND) -> _Float64ND | _Complex128ND: ...
110112
@overload # +floating
111113
def logm(A: onp.ToFloatND) -> _InexactND: ...
112114
@overload # ~complex128
@@ -127,7 +129,7 @@ def logm(A: onp.ToComplexND, disp: onp.ToFalse) -> tuple[_InexactND, float]: ...
127129

128130
#
129131
@overload # +integer | ~float64
130-
def expm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
132+
def expm(A: _AsFloat64ND) -> _Float64ND: ...
131133
@overload # +floating
132134
def expm(A: onp.ToFloatND) -> _FloatND: ...
133135
@overload # ~complex128
@@ -142,7 +144,7 @@ def _exp_sinch(x: onp.ArrayND[_ComplexT, _ShapeT]) -> onp.ArrayND[_ComplexT, _Sh
142144

143145
#
144146
@overload # +integer | ~float64
145-
def cosm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
147+
def cosm(A: _AsFloat64ND) -> _Float64ND: ...
146148
@overload # +floating
147149
def cosm(A: onp.ToFloatND) -> _FloatND: ...
148150
@overload # ~complex128
@@ -154,7 +156,7 @@ def cosm(A: onp.ToComplexND) -> _InexactND: ...
154156

155157
#
156158
@overload # +integer | ~float64
157-
def sinm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
159+
def sinm(A: _AsFloat64ND) -> _Float64ND: ...
158160
@overload # +floating
159161
def sinm(A: onp.ToFloatND) -> _FloatND: ...
160162
@overload # ~complex128
@@ -166,7 +168,7 @@ def sinm(A: onp.ToComplexND) -> _InexactND: ...
166168

167169
#
168170
@overload # +integer | ~float64
169-
def tanm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
171+
def tanm(A: _AsFloat64ND) -> _Float64ND: ...
170172
@overload # +floating
171173
def tanm(A: onp.ToFloatND) -> _FloatND: ...
172174
@overload # ~complex128
@@ -178,7 +180,7 @@ def tanm(A: onp.ToComplexND) -> _InexactND: ...
178180

179181
#
180182
@overload # +integer | ~float64
181-
def coshm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
183+
def coshm(A: _AsFloat64ND) -> _Float64ND: ...
182184
@overload # +floating
183185
def coshm(A: onp.ToFloatND) -> _FloatND: ...
184186
@overload # ~complex128
@@ -190,7 +192,7 @@ def coshm(A: onp.ToComplexND) -> _InexactND: ...
190192

191193
#
192194
@overload # +integer | ~float64
193-
def sinhm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
195+
def sinhm(A: _AsFloat64ND) -> _Float64ND: ...
194196
@overload # +floating
195197
def sinhm(A: onp.ToFloatND) -> _FloatND: ...
196198
@overload # ~complex128
@@ -202,7 +204,7 @@ def sinhm(A: onp.ToComplexND) -> _InexactND: ...
202204

203205
#
204206
@overload # +integer | ~float64
205-
def tanhm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
207+
def tanhm(A: _AsFloat64ND) -> _Float64ND: ...
206208
@overload # +floating
207209
def tanhm(A: onp.ToFloatND) -> _FloatND: ...
208210
@overload # ~complex128
@@ -228,7 +230,7 @@ def funm(A: onp.CanArrayND[_InexactT], func: _FuncND[_InexactT], disp: onp.ToFal
228230

229231
#
230232
@overload # +float64
231-
def signm(A: onp.ToIntND | onp.ToFloat64_ND) -> _Float64ND: ...
233+
def signm(A: _AsFloat64ND) -> _Float64ND: ...
232234
@overload # +floating
233235
def signm(A: onp.ToFloatND) -> _FloatND: ...
234236
@overload # +complexfloating
@@ -248,7 +250,7 @@ def signm(A: onp.ToComplexND, disp: onp.ToTrue) -> _InexactND: ...
248250
def signm(A: onp.ToComplexND, disp: onp.ToFalse) -> tuple[_InexactND, np.float64]: ...
249251

250252
#
251-
@overload # +integer, +integer
253+
@overload # +integer | ~float64
252254
def khatri_rao(a: onp.ToIntND, b: onp.ToIntND) -> _IntND: ...
253255
@overload # +float64, ~float64
254256
def khatri_rao(a: onp.ToFloat64_ND, b: onp.ToJustFloat64_ND) -> _Float64ND: ...

0 commit comments

Comments
 (0)