Skip to content

🤡 workaround for path-dependent pyright bug by replicating CanArrayND #742

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions scipy-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from collections.abc import Sequence
from types import TracebackType
from typing import SupportsIndex, TypeAlias, final, type_check_only
from typing import Any, Protocol, SupportsIndex, TypeAlias, final, type_check_only
from typing_extensions import TypeVar

__all__ = "AnyShape", "ExitMixin"
import numpy as np

__all__ = "AnyShape", "CanArrayND", "ExitMixin"

# helper mixins
@type_check_only
Expand All @@ -14,3 +17,17 @@ class ExitMixin:

# equivalent to `numpy._typing._shape._ShapeLike`
AnyShape: TypeAlias = SupportsIndex | Sequence[SupportsIndex]

# NOTE: For some reason, `onp.CanArrayND` isn't understood by pyright when running `uv run pyright tests`, even though it works
# fine when running `uv run pyright` in the root directory (same story for basedpyright). By copying the definition here, these
# Pyright won't report false positives (no idea why though), so this is but a workaround.
# https://github.com/jorenham/optype/blob/abf1758/optype/numpy/_array.py#L124-L133
# TODO(jorenham): Remove this workaround once the issue is fixed in Pyright.

_SCT_co = TypeVar("_SCT_co", bound=np.generic, covariant=True)
_NDT_co = TypeVar("_NDT_co", bound=tuple[int, ...], default=tuple[Any, ...], covariant=True)

@type_check_only
class CanArrayND(Protocol[_SCT_co, _NDT_co]):
def __len__(self, /) -> int: ...
def __array__(self, /) -> np.ndarray[_NDT_co, np.dtype[_SCT_co]]: ...
23 changes: 13 additions & 10 deletions scipy-stubs/fft/_basic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,26 @@ import numpy as np
import optype.numpy as onp
import optype.numpy.compat as npc

from scipy._typing import AnyShape
from scipy._typing import (
AnyShape,
CanArrayND, # path-dependent Pyright bug workaround
)

_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])

_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
_Unused: TypeAlias = Never # not used by scipy

_AsFloat32: TypeAlias = onp.CanArrayND[npc.floating32, _ShapeT]
_AsFloat64: TypeAlias = onp.CanArrayND[np.bool_ | npc.integer | npc.floating64, _ShapeT]
_AsFloat80: TypeAlias = onp.CanArrayND[np.longdouble, _ShapeT]
_CoInteger: TypeAlias = npc.integer | np.bool_
_AsFloat32: TypeAlias = CanArrayND[npc.floating32, _ShapeT]
_AsFloat64: TypeAlias = CanArrayND[npc.floating64 | _CoInteger, _ShapeT]
_AsFloat80: TypeAlias = CanArrayND[np.longdouble, _ShapeT]
_AsComplex64: TypeAlias = CanArrayND[npc.inexact32, _ShapeT]
_AsComplex128: TypeAlias = CanArrayND[npc.inexact64 | _CoInteger, _ShapeT]
_AsComplex160: TypeAlias = CanArrayND[np.longdouble | np.clongdouble, _ShapeT]

_AsComplex64: TypeAlias = onp.CanArrayND[np.float32 | np.complex64, _ShapeT]
_AsComplex128: TypeAlias = onp.CanArrayND[np.bool_ | npc.integer | np.float64 | np.complex128, _ShapeT]
_AsComplex160: TypeAlias = onp.CanArrayND[np.longdouble | np.clongdouble, _ShapeT]

_ToFloat64_ND: TypeAlias = onp.ToArrayND[float, np.bool_ | npc.integer | np.float64]
_ToComplex128_ND: TypeAlias = onp.ToArrayND[complex, np.bool_ | npc.integer | npc.inexact64]
_ToFloat64_ND: TypeAlias = onp.ToArrayND[float, npc.floating64 | _CoInteger]
_ToComplex128_ND: TypeAlias = onp.ToArrayND[complex, npc.inexact64 | _CoInteger]

###
# 1-D
Expand Down
18 changes: 10 additions & 8 deletions scipy-stubs/fft/_fftlog.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@ import numpy as np
import optype.numpy as onp
import optype.numpy.compat as npc

from scipy._typing import CanArrayND # path-dependent Pyright bug workaround

__all__ = ["fht", "fhtoffset", "ifht"]

_FloatT = TypeVar("_FloatT", bound=np.float32 | np.float64 | np.longdouble)
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])

###

@overload
def fht(
a: onp.CanArrayND[_FloatT, _ShapeT], dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.ArrayND[_FloatT, _ShapeT]: ...
@overload
def fht(
a: Sequence[float], dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
Expand All @@ -29,16 +27,16 @@ def fht(
a: Sequence[Sequence[Sequence[float]]], dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.Array3D[np.float64]: ...
@overload
def fht(
a: CanArrayND[_FloatT, _ShapeT], dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.ArrayND[_FloatT, _ShapeT]: ...
@overload
def fht(
a: onp.ToFloatND, dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.ArrayND[npc.floating]: ...

#
@overload
def ifht(
A: onp.CanArrayND[_FloatT, _ShapeT], dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.ArrayND[_FloatT, _ShapeT]: ...
@overload
def ifht(
A: Sequence[float], dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.Array1D[np.float64]: ...
Expand All @@ -51,6 +49,10 @@ def ifht(
A: Sequence[Sequence[Sequence[float]]], dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.Array3D[np.float64]: ...
@overload
def ifht(
A: CanArrayND[_FloatT, _ShapeT], dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.ArrayND[_FloatT, _ShapeT]: ...
@overload
def ifht(
A: onp.ToFloatND, dln: onp.ToFloat, mu: onp.ToFloat, offset: onp.ToFloat = 0.0, bias: onp.ToFloat = 0.0
) -> onp.ArrayND[npc.floating]: ...
Expand Down
37 changes: 20 additions & 17 deletions scipy-stubs/fft/_realtransforms.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import optype.numpy as onp
import optype.numpy.compat as npc

from ._typing import DCTType, NormalizationMode
from scipy._typing import AnyShape
from scipy._typing import (
AnyShape,
CanArrayND, # path-dependent Pyright bug workaround
)

__all__ = ["dct", "dctn", "dst", "dstn", "idct", "idctn", "idst", "idstn"]

Expand All @@ -21,7 +24,7 @@ _FloatND: TypeAlias = onp.ArrayND[np.float32 | np.float64 | np.longdouble] # do

@overload
def dctn(
x: onp.CanArrayND[npc.integer, _ShapeT],
x: CanArrayND[npc.integer, _ShapeT],
type: DCTType = 2,
s: _ToIntOrND | None = None,
axes: AnyShape | None = None,
Expand All @@ -33,7 +36,7 @@ def dctn(
) -> onp.Array[_ShapeT, np.float64]: ...
@overload
def dctn(
x: onp.CanArrayND[np.float16, _ShapeT],
x: CanArrayND[np.float16, _ShapeT],
type: DCTType = 2,
s: _ToIntOrND | None = None,
axes: AnyShape | None = None,
Expand Down Expand Up @@ -83,7 +86,7 @@ def dctn(
#
@overload
def idctn(
x: onp.CanArrayND[npc.integer, _ShapeT],
x: CanArrayND[npc.integer, _ShapeT],
type: DCTType = 2,
s: _ToIntOrND | None = None,
axes: AnyShape | None = None,
Expand All @@ -94,7 +97,7 @@ def idctn(
) -> onp.Array[_ShapeT, np.float64]: ...
@overload
def idctn(
x: onp.CanArrayND[np.float16, _ShapeT],
x: CanArrayND[np.float16, _ShapeT],
type: DCTType = 2,
s: _ToIntOrND | None = None,
axes: AnyShape | None = None,
Expand Down Expand Up @@ -140,7 +143,7 @@ def idctn(
#
@overload
def dstn(
x: onp.CanArrayND[npc.integer, _ShapeT],
x: CanArrayND[npc.integer, _ShapeT],
type: DCTType = 2,
s: _ToIntOrND | None = None,
axes: AnyShape | None = None,
Expand All @@ -151,7 +154,7 @@ def dstn(
) -> onp.Array[_ShapeT, np.float64]: ...
@overload
def dstn(
x: onp.CanArrayND[np.float16, _ShapeT],
x: CanArrayND[np.float16, _ShapeT],
type: DCTType = 2,
s: _ToIntOrND | None = None,
axes: AnyShape | None = None,
Expand Down Expand Up @@ -197,7 +200,7 @@ def dstn(
#
@overload
def idstn(
x: onp.CanArrayND[npc.integer, _ShapeT],
x: CanArrayND[npc.integer, _ShapeT],
type: DCTType = 2,
s: _ToIntOrND | None = None,
axes: AnyShape | None = None,
Expand All @@ -208,7 +211,7 @@ def idstn(
) -> onp.Array[_ShapeT, np.float64]: ...
@overload
def idstn(
x: onp.CanArrayND[np.float16, _ShapeT],
x: CanArrayND[np.float16, _ShapeT],
type: DCTType = 2,
s: _ToIntOrND | None = None,
axes: AnyShape | None = None,
Expand Down Expand Up @@ -254,7 +257,7 @@ def idstn(
#
@overload
def dct(
x: onp.CanArrayND[np.integer, _ShapeT],
x: CanArrayND[np.integer, _ShapeT],
type: DCTType = 2,
n: onp.ToInt | None = None,
axis: op.CanIndex = -1,
Expand All @@ -265,7 +268,7 @@ def dct(
) -> onp.Array[_ShapeT, np.float64]: ...
@overload
def dct(
x: onp.CanArrayND[np.float16, _ShapeT],
x: CanArrayND[np.float16, _ShapeT],
type: DCTType = 2,
n: onp.ToInt | None = None,
axis: op.CanIndex = -1,
Expand Down Expand Up @@ -311,7 +314,7 @@ def dct(
#
@overload
def idct(
x: onp.CanArrayND[np.integer, _ShapeT],
x: CanArrayND[np.integer, _ShapeT],
type: DCTType = 2,
n: onp.ToInt | None = None,
axis: op.CanIndex = -1,
Expand All @@ -322,7 +325,7 @@ def idct(
) -> onp.Array[_ShapeT, np.float64]: ...
@overload
def idct(
x: onp.CanArrayND[np.float16, _ShapeT],
x: CanArrayND[np.float16, _ShapeT],
type: DCTType = 2,
n: onp.ToInt | None = None,
axis: op.CanIndex = -1,
Expand Down Expand Up @@ -368,7 +371,7 @@ def idct(
#
@overload
def dst(
x: onp.CanArrayND[np.integer, _ShapeT],
x: CanArrayND[np.integer, _ShapeT],
type: DCTType = 2,
n: onp.ToInt | None = None,
axis: op.CanIndex = -1,
Expand All @@ -379,7 +382,7 @@ def dst(
) -> onp.Array[_ShapeT, np.float64]: ...
@overload
def dst(
x: onp.CanArrayND[np.float16, _ShapeT],
x: CanArrayND[np.float16, _ShapeT],
type: DCTType = 2,
n: onp.ToInt | None = None,
axis: op.CanIndex = -1,
Expand Down Expand Up @@ -425,7 +428,7 @@ def dst(
#
@overload
def idst(
x: onp.CanArrayND[np.integer, _ShapeT],
x: CanArrayND[np.integer, _ShapeT],
type: DCTType = 2,
n: onp.ToInt | None = None,
axis: op.CanIndex = -1,
Expand All @@ -436,7 +439,7 @@ def idst(
) -> onp.Array[_ShapeT, np.float64]: ...
@overload
def idst(
x: onp.CanArrayND[np.float16, _ShapeT],
x: CanArrayND[np.float16, _ShapeT],
type: DCTType = 2,
n: onp.ToInt | None = None,
axis: op.CanIndex = -1,
Expand Down
Loading