Skip to content

Commit 5539a91

Browse files
committed
linalg: improved eigh return type inference for float64
1 parent 33eb269 commit 5539a91

File tree

1 file changed

+31
-16
lines changed

1 file changed

+31
-16
lines changed

scipy-stubs/linalg/_decomp.pyi

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,22 @@ def eig(
384384
...
385385

386386
#
387-
@overload # float, eigvals_only: False = ...
387+
@overload # +float64, eigvals_only: False = ...
388+
def eigh( #
389+
a: onp.ToArrayND[float, np.float64 | np.longdouble | npc.integer64 | npc.integer32],
390+
b: onp.ToFloat64_ND | None = None,
391+
*,
392+
lower: op.CanBool = True,
393+
eigvals_only: onp.ToFalse = False,
394+
overwrite_a: op.CanBool = False,
395+
overwrite_b: op.CanBool = False,
396+
type: _EigHType = 1,
397+
check_finite: op.CanBool = True,
398+
subset_by_index: _EigHSubsetByIndex | None = None,
399+
subset_by_value: _EigHSubsetByValue | None = None,
400+
driver: _DriverEV | _DriverGV | None = None,
401+
) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.float64]]: ...
402+
@overload # +float, eigvals_only: False = ...
388403
def eigh(
389404
a: onp.ToFloatND,
390405
b: onp.ToFloatND | None = None,
@@ -399,22 +414,22 @@ def eigh(
399414
subset_by_value: _EigHSubsetByValue | None = None,
400415
driver: _DriverEV | _DriverGV | None = None,
401416
) -> tuple[_FloatND, _FloatND]: ...
402-
@overload # float, eigvals_only: True
417+
@overload # ~complex, eigvals_only: False = ...
403418
def eigh(
404-
a: onp.ToFloatND,
405-
b: onp.ToFloatND | None = None,
419+
a: onp.ToJustComplexND,
420+
b: onp.ToComplexND | None = None,
406421
*,
407422
lower: op.CanBool = True,
408-
eigvals_only: onp.ToTrue,
423+
eigvals_only: onp.ToFalse = False,
409424
overwrite_a: op.CanBool = False,
410425
overwrite_b: op.CanBool = False,
411426
type: _EigHType = 1,
412427
check_finite: op.CanBool = True,
413428
subset_by_index: _EigHSubsetByIndex | None = None,
414429
subset_by_value: _EigHSubsetByValue | None = None,
415-
driver: _DriverEV | _EigHSubsetByValue | None = None,
416-
) -> _FloatND: ...
417-
@overload # complex, eigvals_only: False = ...
430+
driver: _DriverEV | _DriverGV | None = None,
431+
) -> tuple[_FloatND, _ComplexND]: ...
432+
@overload # +complex, eigvals_only: False = ...
418433
def eigh(
419434
a: onp.ToComplexND,
420435
b: onp.ToComplexND | None = None,
@@ -429,10 +444,10 @@ def eigh(
429444
subset_by_value: _EigHSubsetByValue | None = None,
430445
driver: _DriverEV | _DriverGV | None = None,
431446
) -> tuple[_FloatND, _InexactND]: ...
432-
@overload # complex, eigvals_only: True
447+
@overload # +complex128, eigvals_only: True
433448
def eigh(
434-
a: onp.ToComplexND,
435-
b: onp.ToComplexND | None = None,
449+
a: onp.ToArrayND[float, npc.inexact80 | npc.number64 | npc.integer32],
450+
b: onp.ToComplex128_ND | None = None,
436451
*,
437452
lower: op.CanBool = True,
438453
eigvals_only: onp.ToTrue,
@@ -443,22 +458,22 @@ def eigh(
443458
subset_by_index: _EigHSubsetByIndex | None = None,
444459
subset_by_value: _EigHSubsetByValue | None = None,
445460
driver: _DriverEV | _EigHSubsetByValue | None = None,
446-
) -> _FloatND: ...
447-
@overload # complex, eigvals_only: CanBool (catch-all)
461+
) -> onp.ArrayND[np.float64]: ...
462+
@overload # +complex, eigvals_only: True
448463
def eigh(
449464
a: onp.ToComplexND,
450465
b: onp.ToComplexND | None = None,
451466
*,
452-
lower: op.CanBool,
453-
eigvals_only: op.CanBool,
467+
lower: op.CanBool = True,
468+
eigvals_only: onp.ToTrue,
454469
overwrite_a: op.CanBool = False,
455470
overwrite_b: op.CanBool = False,
456471
type: _EigHType = 1,
457472
check_finite: op.CanBool = True,
458473
subset_by_index: _EigHSubsetByIndex | None = None,
459474
subset_by_value: _EigHSubsetByValue | None = None,
460475
driver: _DriverEV | _EigHSubsetByValue | None = None,
461-
) -> _FloatND | tuple[_FloatND, _InexactND]: ...
476+
) -> _FloatND: ...
462477

463478
#
464479
@overload # float, eigvals_only: False = ..., select: _SelectA = ...

0 commit comments

Comments
 (0)