From a45813dc598bc044e2af3b05d586ba3d8a35e79a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Sun, 14 Apr 2024 10:39:03 +0200 Subject: [PATCH 1/3] add type guards --- .ruff.toml | 1 - pyproject.toml | 1 + sphinx/ext/autodoc/__init__.py | 2 +- sphinx/util/inspect.py | 74 +++++++++++++++++++++++++--------- sphinx/util/typing.py | 6 ++- 5 files changed, 62 insertions(+), 22 deletions(-) diff --git a/.ruff.toml b/.ruff.toml index 4a19e8ed4e4..d0834b76959 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -553,7 +553,6 @@ exclude = [ "sphinx/util/cfamily.py", "sphinx/util/math.py", "sphinx/util/logging.py", - "sphinx/util/inspect.py", "sphinx/util/parallel.py", "sphinx/util/inventory.py", "sphinx/util/__init__.py", diff --git a/pyproject.toml b/pyproject.toml index 8e8dac4a4f8..9be8d58e7fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ lint = [ "sphinx-lint", "types-docutils", "types-requests", + "typing-extensions", "importlib_metadata", # for mypy (Python<=3.9) "tomli", # for mypy (Python<=3.10) "pytest>=6.0", diff --git a/sphinx/ext/autodoc/__init__.py b/sphinx/ext/autodoc/__init__.py index 16eaa8e75ec..0887876adb5 100644 --- a/sphinx/ext/autodoc/__init__.py +++ b/sphinx/ext/autodoc/__init__.py @@ -2267,7 +2267,7 @@ def format_signature(self, **kwargs: Any) -> str: pass # default implementation. skipped. else: if inspect.isclassmethod(func): - func = func.__func__ + func = func.__func__ # type: ignore[attr-defined] dispatchmeth = self.annotate_to_first_argument(func, typ) if dispatchmeth: documenter = MethodDocumenter(self.directive, '') diff --git a/sphinx/util/inspect.py b/sphinx/util/inspect.py index dfd1d01ef5b..ebc90b8f88c 100644 --- a/sphinx/util/inspect.py +++ b/sphinx/util/inspect.py @@ -5,13 +5,13 @@ import ast import builtins import contextlib -import enum import inspect import re import sys import types import typing from collections.abc import Mapping +from enum import Enum from functools import cached_property, partial, partialmethod, singledispatchmethod from importlib import import_module from inspect import Parameter, Signature @@ -26,8 +26,44 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence from inspect import _ParameterKind - from types import MethodType, ModuleType - from typing import Final + from typing import Final, Protocol, Union + + from typing_extensions import TypeGuard + + class _SupportsGet(Protocol): + def __get__(self, __instance: Any, __owner: type | None = ...) -> Any: ... # NoQA: E704 + + class _SupportsSet(Protocol): + # instance and value are contravariants but we do not need that precision + def __set__(self, __instance: Any, __value: Any) -> None: ... # NoQA: E704 + + class _SupportsDelete(Protocol): + # instance is contravariant but we do not need that precision + def __delete__(self, __instance: Any) -> None: ... # NoQA: E704 + + class _GenericAliasLike(Protocol): + # Minimalist interface for a generic alias in typing.py. + + # The ``__origin__`` is defined both on the base type for generics + # in typing.py *and* on the public ``types.GenericAlias`` type. + __origin__: type + + # Note that special generic alias types (tuple and callables) do + # not directly define ``__args__``. At runtime, however, they are + # actually instances of ``typing._GenericAlias`` which does have + # an ``__args__`` field. + __args__: tuple[Any, ...] + + _RoutineType = Union[ + types.FunctionType, + types.LambdaType, + types.MethodType, + types.BuiltinFunctionType, + types.BuiltinMethodType, + types.WrapperDescriptorType, + types.MethodDescriptorType, + types.ClassMethodDescriptorType, + ] logger = logging.getLogger(__name__) @@ -90,7 +126,7 @@ class methods and static methods. def getall(obj: Any) -> Sequence[str] | None: - """Get the ``__all__`` attribute of an object as sequence. + """Get the ``__all__`` attribute of an object as a sequence. This returns ``None`` if the given ``obj.__all__`` does not exist and raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of @@ -184,14 +220,14 @@ def isNewType(obj: Any) -> bool: return __module__ == 'typing' and __qualname__ == 'NewType..new_type' -def isenumclass(x: Any) -> bool: +def isenumclass(x: Any) -> TypeGuard[type[Enum]]: """Check if the object is an :class:`enumeration class `.""" - return isclass(x) and issubclass(x, enum.Enum) + return isclass(x) and issubclass(x, Enum) -def isenumattribute(x: Any) -> bool: +def isenumattribute(x: Any) -> TypeGuard[Enum]: """Check if the object is an enumeration attribute.""" - return isinstance(x, enum.Enum) + return isinstance(x, Enum) def unpartial(obj: Any) -> Any: @@ -206,7 +242,7 @@ def unpartial(obj: Any) -> Any: return obj -def ispartial(obj: Any) -> bool: +def ispartial(obj: Any) -> TypeGuard[partial | partialmethod]: """Check if the object is a partial function or method.""" return isinstance(obj, (partial, partialmethod)) @@ -241,7 +277,7 @@ def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool: return False -def isdescriptor(x: Any) -> bool: +def isdescriptor(x: Any) -> TypeGuard[_SupportsGet | _SupportsSet | _SupportsDelete]: """Check if the object is a :external+python:term:`descriptor`.""" return any( callable(safe_getattr(x, item, None)) for item in ('__get__', '__set__', '__delete__') @@ -253,7 +289,7 @@ def isabstractmethod(obj: Any) -> bool: return safe_getattr(obj, '__isabstractmethod__', False) is True -def isboundmethod(method: MethodType) -> bool: +def isboundmethod(method: types.MethodType) -> bool: """Check if the method is a bound method.""" return safe_getattr(method, '__self__', None) is not None @@ -308,12 +344,12 @@ def is_singledispatch_function(obj: Any) -> bool: ) -def is_singledispatch_method(obj: Any) -> bool: +def is_singledispatch_method(obj: Any) -> TypeGuard[singledispatchmethod]: """Check if the object is a :class:`~functools.singledispatchmethod`.""" return isinstance(obj, singledispatchmethod) -def isfunction(obj: Any) -> bool: +def isfunction(obj: Any) -> TypeGuard[types.FunctionType]: """Check if the object is a user-defined function. Partial objects are unwrapped before checking them. @@ -323,7 +359,7 @@ def isfunction(obj: Any) -> bool: return inspect.isfunction(unpartial(obj)) -def isbuiltin(obj: Any) -> bool: +def isbuiltin(obj: Any) -> TypeGuard[types.BuiltinFunctionType]: """Check if the object is a built-in function or method. Partial objects are unwrapped before checking them. @@ -333,7 +369,7 @@ def isbuiltin(obj: Any) -> bool: return inspect.isbuiltin(unpartial(obj)) -def isroutine(obj: Any) -> bool: +def isroutine(obj: Any) -> TypeGuard[_RoutineType]: """Check if the object is a kind of function or method. Partial objects are unwrapped before checking them. @@ -358,7 +394,7 @@ def _is_wrapped_coroutine(obj: Any) -> bool: return hasattr(obj, '__wrapped__') -def isproperty(obj: Any) -> bool: +def isproperty(obj: Any) -> TypeGuard[property | cached_property]: """Check if the object is property (possibly cached).""" return isinstance(obj, (property, cached_property)) @@ -433,8 +469,8 @@ def object_description(obj: Any, *, _seen: frozenset[int] = frozenset()) -> str: return 'frozenset({%s})' % ', '.join( object_description(x, _seen=seen) for x in sorted_values ) - elif isinstance(obj, enum.Enum): - if obj.__repr__.__func__ is not enum.Enum.__repr__: # type: ignore[attr-defined] + elif isinstance(obj, Enum): + if obj.__repr__.__func__ is not Enum.__repr__: # type: ignore[attr-defined] return repr(obj) return f'{obj.__class__.__name__}.{obj.name}' elif isinstance(obj, tuple): @@ -529,7 +565,7 @@ def __init__(self, modname: str, mapping: Mapping[str, str]) -> None: self.__modname = modname self.__mapping = mapping - self.__module: ModuleType | None = None + self.__module: types.ModuleType | None = None def __getattr__(self, name: str) -> Any: fullname = '.'.join(filter(None, [self.__modname, name])) diff --git a/sphinx/util/typing.py b/sphinx/util/typing.py index 007adca9f25..d7a6e6d85f7 100644 --- a/sphinx/util/typing.py +++ b/sphinx/util/typing.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from collections.abc import Mapping - from typing import Final, Literal + from typing import Final, Literal, Protocol from typing_extensions import TypeAlias @@ -31,6 +31,10 @@ 'smart', ] + class _SpecialFormInterface(Protocol): + _name: str + + if sys.version_info >= (3, 10): from types import UnionType else: From 120dbb71056cd29e99b454df81e191bba880c084 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 05:42:03 +0100 Subject: [PATCH 2/3] Use TypeGuards --- .ruff.toml | 1 + pyproject.toml | 1 - sphinx/ext/autodoc/__init__.py | 2 +- sphinx/ext/autodoc/mock.py | 4 ++- sphinx/util/inspect.py | 66 +++++++++++++++++----------------- sphinx/util/typing.py | 24 ++++++------- 6 files changed, 49 insertions(+), 49 deletions(-) diff --git a/.ruff.toml b/.ruff.toml index d0834b76959..4a19e8ed4e4 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -553,6 +553,7 @@ exclude = [ "sphinx/util/cfamily.py", "sphinx/util/math.py", "sphinx/util/logging.py", + "sphinx/util/inspect.py", "sphinx/util/parallel.py", "sphinx/util/inventory.py", "sphinx/util/__init__.py", diff --git a/pyproject.toml b/pyproject.toml index 9be8d58e7fb..8e8dac4a4f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,6 @@ lint = [ "sphinx-lint", "types-docutils", "types-requests", - "typing-extensions", "importlib_metadata", # for mypy (Python<=3.9) "tomli", # for mypy (Python<=3.10) "pytest>=6.0", diff --git a/sphinx/ext/autodoc/__init__.py b/sphinx/ext/autodoc/__init__.py index 0887876adb5..16eaa8e75ec 100644 --- a/sphinx/ext/autodoc/__init__.py +++ b/sphinx/ext/autodoc/__init__.py @@ -2267,7 +2267,7 @@ def format_signature(self, **kwargs: Any) -> str: pass # default implementation. skipped. else: if inspect.isclassmethod(func): - func = func.__func__ # type: ignore[attr-defined] + func = func.__func__ dispatchmeth = self.annotate_to_first_argument(func, typ) if dispatchmeth: documenter = MethodDocumenter(self.directive, '') diff --git a/sphinx/ext/autodoc/mock.py b/sphinx/ext/autodoc/mock.py index f17c3302cb6..7639c46265b 100644 --- a/sphinx/ext/autodoc/mock.py +++ b/sphinx/ext/autodoc/mock.py @@ -17,6 +17,8 @@ from collections.abc import Iterator, Sequence from typing import Any + from typing_extensions import TypeGuard + logger = logging.getLogger(__name__) @@ -154,7 +156,7 @@ def mock(modnames: list[str]) -> Iterator[None]: finder.invalidate_caches() -def ismockmodule(subject: Any) -> bool: +def ismockmodule(subject: Any) -> TypeGuard[_MockModule]: """Check if the object is a mocked module.""" return isinstance(subject, _MockModule) diff --git a/sphinx/util/inspect.py b/sphinx/util/inspect.py index ebc90b8f88c..632ed0a2b9f 100644 --- a/sphinx/util/inspect.py +++ b/sphinx/util/inspect.py @@ -5,13 +5,13 @@ import ast import builtins import contextlib +import enum import inspect import re import sys import types import typing from collections.abc import Mapping -from enum import Enum from functools import cached_property, partial, partialmethod, singledispatchmethod from importlib import import_module from inspect import Parameter, Signature @@ -25,10 +25,12 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence + from enum import Enum from inspect import _ParameterKind + from types import MethodType, ModuleType from typing import Final, Protocol, Union - from typing_extensions import TypeGuard + from typing_extensions import TypeAlias, TypeGuard class _SupportsGet(Protocol): def __get__(self, __instance: Any, __owner: type | None = ...) -> Any: ... # NoQA: E704 @@ -41,20 +43,7 @@ class _SupportsDelete(Protocol): # instance is contravariant but we do not need that precision def __delete__(self, __instance: Any) -> None: ... # NoQA: E704 - class _GenericAliasLike(Protocol): - # Minimalist interface for a generic alias in typing.py. - - # The ``__origin__`` is defined both on the base type for generics - # in typing.py *and* on the public ``types.GenericAlias`` type. - __origin__: type - - # Note that special generic alias types (tuple and callables) do - # not directly define ``__args__``. At runtime, however, they are - # actually instances of ``typing._GenericAlias`` which does have - # an ``__args__`` field. - __args__: tuple[Any, ...] - - _RoutineType = Union[ + _RoutineType: TypeAlias = Union[ types.FunctionType, types.LambdaType, types.MethodType, @@ -64,6 +53,11 @@ class _GenericAliasLike(Protocol): types.MethodDescriptorType, types.ClassMethodDescriptorType, ] + _SignatureType: TypeAlias = Union[ + Callable[..., Any], + staticmethod, + classmethod, + ] logger = logging.getLogger(__name__) @@ -126,7 +120,7 @@ class methods and static methods. def getall(obj: Any) -> Sequence[str] | None: - """Get the ``__all__`` attribute of an object as a sequence. + """Get the ``__all__`` attribute of an object as sequence. This returns ``None`` if the given ``obj.__all__`` does not exist and raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of @@ -222,12 +216,12 @@ def isNewType(obj: Any) -> bool: def isenumclass(x: Any) -> TypeGuard[type[Enum]]: """Check if the object is an :class:`enumeration class `.""" - return isclass(x) and issubclass(x, Enum) + return isclass(x) and issubclass(x, enum.Enum) def isenumattribute(x: Any) -> TypeGuard[Enum]: """Check if the object is an enumeration attribute.""" - return isinstance(x, Enum) + return isinstance(x, enum.Enum) def unpartial(obj: Any) -> Any: @@ -247,7 +241,11 @@ def ispartial(obj: Any) -> TypeGuard[partial | partialmethod]: return isinstance(obj, (partial, partialmethod)) -def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool: +def isclassmethod( + obj: Any, + cls: Any = None, + name: str | None = None, +) -> TypeGuard[classmethod]: """Check if the object is a :class:`classmethod`.""" if isinstance(obj, classmethod): return True @@ -263,7 +261,11 @@ def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool: return False -def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool: +def isstaticmethod( + obj: Any, + cls: Any = None, + name: str | None = None, +) -> TypeGuard[staticmethod]: """Check if the object is a :class:`staticmethod`.""" if isinstance(obj, staticmethod): return True @@ -289,7 +291,7 @@ def isabstractmethod(obj: Any) -> bool: return safe_getattr(obj, '__isabstractmethod__', False) is True -def isboundmethod(method: types.MethodType) -> bool: +def isboundmethod(method: MethodType) -> bool: """Check if the method is a bound method.""" return safe_getattr(method, '__self__', None) is not None @@ -379,7 +381,7 @@ def isroutine(obj: Any) -> TypeGuard[_RoutineType]: return inspect.isroutine(unpartial(obj)) -def iscoroutinefunction(obj: Any) -> bool: +def iscoroutinefunction(obj: Any) -> TypeGuard[Callable[..., types.CoroutineType]]: """Check if the object is a :external+python:term:`coroutine` function.""" obj = unwrap_all(obj, stop=_is_wrapped_coroutine) return inspect.iscoroutinefunction(obj) @@ -399,7 +401,7 @@ def isproperty(obj: Any) -> TypeGuard[property | cached_property]: return isinstance(obj, (property, cached_property)) -def isgenericalias(obj: Any) -> bool: +def isgenericalias(obj: Any) -> TypeGuard[types.GenericAlias]: """Check if the object is a generic alias.""" return isinstance(obj, (types.GenericAlias, typing._BaseGenericAlias)) # type: ignore[attr-defined] @@ -469,8 +471,8 @@ def object_description(obj: Any, *, _seen: frozenset[int] = frozenset()) -> str: return 'frozenset({%s})' % ', '.join( object_description(x, _seen=seen) for x in sorted_values ) - elif isinstance(obj, Enum): - if obj.__repr__.__func__ is not Enum.__repr__: # type: ignore[attr-defined] + elif isinstance(obj, enum.Enum): + if obj.__repr__.__func__ is not enum.Enum.__repr__: # type: ignore[attr-defined] return repr(obj) return f'{obj.__class__.__name__}.{obj.name}' elif isinstance(obj, tuple): @@ -565,7 +567,7 @@ def __init__(self, modname: str, mapping: Mapping[str, str]) -> None: self.__modname = modname self.__mapping = mapping - self.__module: types.ModuleType | None = None + self.__module: ModuleType | None = None def __getattr__(self, name: str) -> Any: fullname = '.'.join(filter(None, [self.__modname, name])) @@ -615,7 +617,7 @@ def __getitem__(self, key: str) -> Any: raise KeyError -def _should_unwrap(subject: Callable[..., Any]) -> bool: +def _should_unwrap(subject: _SignatureType) -> bool: """Check the function should be unwrapped on getting signature.""" __globals__ = getglobals(subject) # contextmanger should be unwrapped @@ -626,7 +628,7 @@ def _should_unwrap(subject: Callable[..., Any]) -> bool: def signature( - subject: Callable[..., Any], + subject: _SignatureType, bound_method: bool = False, type_aliases: Mapping[str, str] | None = None, ) -> Signature: @@ -639,12 +641,12 @@ def signature( try: if _should_unwrap(subject): - signature = inspect.signature(subject) + signature = inspect.signature(subject) # type: ignore[arg-type] else: - signature = inspect.signature(subject, follow_wrapped=True) + signature = inspect.signature(subject, follow_wrapped=True) # type: ignore[arg-type] except ValueError: # follow built-in wrappers up (ex. functools.lru_cache) - signature = inspect.signature(subject) + signature = inspect.signature(subject) # type: ignore[arg-type] parameters = list(signature.parameters.values()) return_annotation = signature.return_annotation diff --git a/sphinx/util/typing.py b/sphinx/util/typing.py index d7a6e6d85f7..39056f91b44 100644 --- a/sphinx/util/typing.py +++ b/sphinx/util/typing.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from collections.abc import Mapping - from typing import Final, Literal, Protocol + from typing import Final, Literal from typing_extensions import TypeAlias @@ -31,10 +31,6 @@ 'smart', ] - class _SpecialFormInterface(Protocol): - _name: str - - if sys.version_info >= (3, 10): from types import UnionType else: @@ -168,7 +164,7 @@ def is_system_TypeVar(typ: Any) -> bool: return modname == 'typing' and isinstance(typ, TypeVar) -def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typing') -> str: +def restify(cls: Any, mode: _RestifyMode = 'fully-qualified-except-typing') -> str: """Convert python class to a reST reference. :param mode: Specify a method how annotations will be stringified. @@ -233,17 +229,17 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin return f':py:class:`{cls.__name__}`' elif (inspect.isgenericalias(cls) and cls.__module__ == 'typing' - and cls.__origin__ is Union): # type: ignore[attr-defined] + and cls.__origin__ is Union): # *cls* is defined in ``typing``, and thus ``__args__`` must exist - return ' | '.join(restify(a, mode) for a in cls.__args__) # type: ignore[attr-defined] + return ' | '.join(restify(a, mode) for a in cls.__args__) elif inspect.isgenericalias(cls): - if isinstance(cls.__origin__, typing._SpecialForm): # type: ignore[attr-defined] - text = restify(cls.__origin__, mode) # type: ignore[attr-defined,arg-type] + if isinstance(cls.__origin__, typing._SpecialForm): + text = restify(cls.__origin__, mode) elif getattr(cls, '_name', None): - cls_name = cls._name # type: ignore[attr-defined] + cls_name = cls._name text = f':py:class:`{modprefix}{cls.__module__}.{cls_name}`' else: - text = restify(cls.__origin__, mode) # type: ignore[attr-defined] + text = restify(cls.__origin__, mode) origin = getattr(cls, '__origin__', None) if not hasattr(cls, '__args__'): # NoQA: SIM114 @@ -251,7 +247,7 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin elif all(is_system_TypeVar(a) for a in cls.__args__): # Suppress arguments if all system defined TypeVars (ex. Dict[KT, VT]) pass - elif cls.__module__ == 'typing' and cls._name == 'Callable': # type: ignore[attr-defined] + elif cls.__module__ == 'typing' and cls._name == 'Callable': args = ', '.join(restify(a, mode) for a in cls.__args__[:-1]) text += fr'\ [[{args}], {restify(cls.__args__[-1], mode)}]' elif cls.__module__ == 'typing' and getattr(origin, '_name', None) == 'Literal': @@ -263,7 +259,7 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin return text elif isinstance(cls, typing._SpecialForm): - return f':py:obj:`~{cls.__module__}.{cls._name}`' + return f':py:obj:`~{cls.__module__}.{cls._name}`' # type: ignore[attr-defined] elif sys.version_info[:2] >= (3, 11) and cls is typing.Any: # handle bpo-46998 return f':py:obj:`~{cls.__module__}.{cls.__name__}`' From 603f2d0dfc9abecd60909a67e0b5506dc715d8c2 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 05:52:43 +0100 Subject: [PATCH 3/3] more --- sphinx/util/inspect.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sphinx/util/inspect.py b/sphinx/util/inspect.py index 632ed0a2b9f..da487a05a59 100644 --- a/sphinx/util/inspect.py +++ b/sphinx/util/inspect.py @@ -25,7 +25,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence - from enum import Enum from inspect import _ParameterKind from types import MethodType, ModuleType from typing import Final, Protocol, Union @@ -120,7 +119,7 @@ class methods and static methods. def getall(obj: Any) -> Sequence[str] | None: - """Get the ``__all__`` attribute of an object as sequence. + """Get the ``__all__`` attribute of an object as a sequence. This returns ``None`` if the given ``obj.__all__`` does not exist and raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of @@ -214,12 +213,12 @@ def isNewType(obj: Any) -> bool: return __module__ == 'typing' and __qualname__ == 'NewType..new_type' -def isenumclass(x: Any) -> TypeGuard[type[Enum]]: +def isenumclass(x: Any) -> TypeGuard[type[enum.Enum]]: """Check if the object is an :class:`enumeration class `.""" return isclass(x) and issubclass(x, enum.Enum) -def isenumattribute(x: Any) -> TypeGuard[Enum]: +def isenumattribute(x: Any) -> TypeGuard[enum.Enum]: """Check if the object is an enumeration attribute.""" return isinstance(x, enum.Enum)