Skip to content

Commit 35fdfcc

Browse files
committed
🏷️ Some minor _lib._util improvements
1 parent 462a15d commit 35fdfcc

File tree

1 file changed

+36
-34
lines changed

1 file changed

+36
-34
lines changed

scipy-stubs/_lib/_util.pyi

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,67 @@
11
import multiprocessing.pool as mpp
22
import types
3-
from collections.abc import Callable, Iterable, Mapping, Sequence
4-
from typing import Concatenate, Final, Generic, NamedTuple, TypeAlias, overload
5-
from typing_extensions import TypeVar, override
3+
from collections.abc import Callable, Iterable, Sequence
4+
from typing import Any, Concatenate, Final, Generic, Literal, NamedTuple, TypeAlias, overload
5+
from typing_extensions import Never, TypeVar, override
66

77
import numpy as np
88
import optype as op
99
import optype.numpy as onp
1010
import optype.numpy.compat as npc
11+
from numpy.random import Generator as Generator # implicit re-export
12+
from optype.numpy.compat import DTypePromotionError as DTypePromotionError # implicit re-export
1113
from scipy._typing import RNG, EnterSelfMixin
1214

13-
_AnyRNG = TypeVar("_AnyRNG", np.random.RandomState, np.random.Generator)
15+
_AnyRNGT = TypeVar("_AnyRNGT", np.random.RandomState, np.random.Generator)
1416

15-
_T = TypeVar("_T", default=object)
16-
_T_co = TypeVar("_T_co", covariant=True, default=object)
17-
_T_contra = TypeVar("_T_contra", contravariant=True, default=object)
1817
_VT = TypeVar("_VT")
1918
_RT = TypeVar("_RT")
19+
20+
_T = TypeVar("_T", default=Any)
21+
_T_co = TypeVar("_T_co", default=Any, covariant=True)
22+
_T_contra = TypeVar("_T_contra", default=Never, contravariant=True)
23+
2024
_AxisT = TypeVar("_AxisT", bound=npc.integer)
2125

2226
###
2327

24-
np_long: Final[type[np.int32 | np.int64]] = ...
25-
np_ulong: Final[type[np.uint32 | np.uint64]] = ...
26-
copy_if_needed: Final[bool | None] = ...
28+
np_long: Final[type[np.int32 | np.int64]] = ... # `np.long` on `numpy>=2`, else `np.int_`
29+
np_ulong: Final[type[np.uint32 | np.uint64]] = ... # `np.ulong` on `numpy>=2`, else `np.uint`
30+
copy_if_needed: Final[Literal[False] | None] = ... # `None` on `numpy>=2`, otherwise `False`
2731

32+
# NOTE: These aliases are implictly exported at runtime
2833
IntNumber: TypeAlias = int | npc.integer
2934
DecimalNumber: TypeAlias = float | npc.floating | npc.integer
30-
3135
_RNG: TypeAlias = np.random.Generator | np.random.RandomState
3236
SeedType: TypeAlias = IntNumber | _RNG | None
33-
# NOTE: This is actually a exported at runtime :(
34-
GeneratorType = TypeVar("GeneratorType", bound=_RNG) # noqa: PYI001
37+
GeneratorType = TypeVar("GeneratorType", bound=_RNG) # noqa: PYI001 # oof
38+
39+
###
3540

3641
class ComplexWarning(RuntimeWarning): ...
3742
class VisibleDeprecationWarning(UserWarning): ...
38-
class DTypePromotionError(TypeError): ...
3943

4044
class AxisError(ValueError, IndexError):
4145
_msg: Final[str | None]
4246
axis: Final[int | None]
43-
ndim: Final[int | None]
44-
47+
ndim: Final[onp.NDim | None]
4548
@overload
4649
def __init__(self, /, axis: str, ndim: None = None, msg_prefix: None = None) -> None: ...
4750
@overload
48-
def __init__(self, /, axis: int, ndim: int, msg_prefix: str | None = None) -> None: ...
51+
def __init__(self, /, axis: int, ndim: onp.NDim, msg_prefix: str | None = None) -> None: ...
4952

5053
class FullArgSpec(NamedTuple):
51-
args: Sequence[str]
54+
args: list[str]
5255
varargs: str | None
5356
varkw: str | None
54-
defaults: tuple[object, ...] | None
55-
kwonlyargs: Sequence[str]
56-
kwonlydefaults: Mapping[str, object] | None
57-
annotations: Mapping[str, type | object | str]
57+
defaults: tuple[Any, ...] | None
58+
kwonlyargs: list[str]
59+
kwonlydefaults: dict[str, Any] | None
60+
annotations: dict[str, Any]
5861

5962
class _FunctionWrapper(Generic[_T_contra, _T_co]):
6063
f: Callable[Concatenate[_T_contra, ...], _T_co]
61-
args: tuple[object, ...]
62-
64+
args: tuple[Any, ...]
6365
@overload
6466
def __init__(self, /, f: Callable[[_T_contra], _T_co], args: tuple[()]) -> None: ...
6567
@overload
@@ -69,8 +71,8 @@ class _FunctionWrapper(Generic[_T_contra, _T_co]):
6971
class MapWrapper(EnterSelfMixin):
7072
pool: int | mpp.Pool | None
7173

72-
def __init__(self, /, pool: Callable[[Callable[[_VT], _RT], Iterable[_VT]], Sequence[_RT]] | int = 1) -> None: ...
73-
def __call__(self, /, func: Callable[[_VT], _RT], iterable: Iterable[_VT]) -> Sequence[_RT]: ...
74+
def __init__(self, /, pool: Callable[[Callable[[_VT], _RT], Iterable[_VT]], Iterable[_RT]] | int = 1) -> None: ...
75+
def __call__(self, /, func: Callable[[_VT], _RT], iterable: Iterable[_VT]) -> Iterable[_RT]: ...
7476
def terminate(self, /) -> None: ...
7577
def join(self, /) -> None: ...
7678
def close(self, /) -> None: ...
@@ -81,23 +83,23 @@ class _RichResult(dict[str, _T]):
8183
def __setattr__(self, name: str, value: _T, /) -> None: ...
8284

8385
#
84-
def float_factorial(n: int) -> float: ...
86+
def float_factorial(n: op.CanIndex) -> float: ... # will be `np.inf` if `n >= 171`
8587

8688
#
8789
def getfullargspec_no_self(func: Callable[..., object]) -> FullArgSpec: ...
8890

8991
#
9092
@overload
91-
def check_random_state(seed: _AnyRNG) -> _AnyRNG: ...
93+
def check_random_state(seed: _AnyRNGT) -> _AnyRNGT: ...
9294
@overload
93-
def check_random_state(seed: int | npc.integer | types.ModuleType | None) -> np.random.RandomState: ...
95+
def check_random_state(seed: onp.ToJustInt | types.ModuleType | None) -> np.random.RandomState: ...
9496

9597
#
9698
@overload
9799
def rng_integers(
98100
gen: RNG | None,
99-
low: onp.ToInt | onp.ToIntND,
100-
high: onp.ToInt | onp.ToIntND | None = None,
101+
low: onp.ToInt,
102+
high: onp.ToInt | None = None,
101103
size: tuple[()] | None = None,
102104
dtype: onp.AnyIntegerDType = "int64",
103105
endpoint: op.CanBool = False,
@@ -114,8 +116,8 @@ def rng_integers(
114116

115117
#
116118
@overload
117-
def normalize_axis_index(axis: int, ndim: int) -> int: ...
119+
def normalize_axis_index(axis: int, ndim: onp.NDim) -> onp.NDim: ...
118120
@overload
119-
def normalize_axis_index(axis: int, ndim: _AxisT) -> _AxisT: ...
121+
def normalize_axis_index(axis: int | _AxisT, ndim: _AxisT) -> _AxisT: ...
120122
@overload
121-
def normalize_axis_index(axis: _AxisT, ndim: int | _AxisT) -> _AxisT: ...
123+
def normalize_axis_index(axis: _AxisT, ndim: onp.NDim | _AxisT) -> _AxisT: ...

0 commit comments

Comments
 (0)