Skip to content

Commit a45813d

Browse files
picnixzAA-Turner
authored andcommitted
add type guards
1 parent b6948b8 commit a45813d

File tree

5 files changed

+62
-22
lines changed

5 files changed

+62
-22
lines changed

.ruff.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,6 @@ exclude = [
553553
"sphinx/util/cfamily.py",
554554
"sphinx/util/math.py",
555555
"sphinx/util/logging.py",
556-
"sphinx/util/inspect.py",
557556
"sphinx/util/parallel.py",
558557
"sphinx/util/inventory.py",
559558
"sphinx/util/__init__.py",

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ lint = [
8888
"sphinx-lint",
8989
"types-docutils",
9090
"types-requests",
91+
"typing-extensions",
9192
"importlib_metadata", # for mypy (Python<=3.9)
9293
"tomli", # for mypy (Python<=3.10)
9394
"pytest>=6.0",

sphinx/ext/autodoc/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2267,7 +2267,7 @@ def format_signature(self, **kwargs: Any) -> str:
22672267
pass # default implementation. skipped.
22682268
else:
22692269
if inspect.isclassmethod(func):
2270-
func = func.__func__
2270+
func = func.__func__ # type: ignore[attr-defined]
22712271
dispatchmeth = self.annotate_to_first_argument(func, typ)
22722272
if dispatchmeth:
22732273
documenter = MethodDocumenter(self.directive, '')

sphinx/util/inspect.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import ast
66
import builtins
77
import contextlib
8-
import enum
98
import inspect
109
import re
1110
import sys
1211
import types
1312
import typing
1413
from collections.abc import Mapping
14+
from enum import Enum
1515
from functools import cached_property, partial, partialmethod, singledispatchmethod
1616
from importlib import import_module
1717
from inspect import Parameter, Signature
@@ -26,8 +26,44 @@
2626
if TYPE_CHECKING:
2727
from collections.abc import Callable, Sequence
2828
from inspect import _ParameterKind
29-
from types import MethodType, ModuleType
30-
from typing import Final
29+
from typing import Final, Protocol, Union
30+
31+
from typing_extensions import TypeGuard
32+
33+
class _SupportsGet(Protocol):
34+
def __get__(self, __instance: Any, __owner: type | None = ...) -> Any: ... # NoQA: E704
35+
36+
class _SupportsSet(Protocol):
37+
# instance and value are contravariants but we do not need that precision
38+
def __set__(self, __instance: Any, __value: Any) -> None: ... # NoQA: E704
39+
40+
class _SupportsDelete(Protocol):
41+
# instance is contravariant but we do not need that precision
42+
def __delete__(self, __instance: Any) -> None: ... # NoQA: E704
43+
44+
class _GenericAliasLike(Protocol):
45+
# Minimalist interface for a generic alias in typing.py.
46+
47+
# The ``__origin__`` is defined both on the base type for generics
48+
# in typing.py *and* on the public ``types.GenericAlias`` type.
49+
__origin__: type
50+
51+
# Note that special generic alias types (tuple and callables) do
52+
# not directly define ``__args__``. At runtime, however, they are
53+
# actually instances of ``typing._GenericAlias`` which does have
54+
# an ``__args__`` field.
55+
__args__: tuple[Any, ...]
56+
57+
_RoutineType = Union[
58+
types.FunctionType,
59+
types.LambdaType,
60+
types.MethodType,
61+
types.BuiltinFunctionType,
62+
types.BuiltinMethodType,
63+
types.WrapperDescriptorType,
64+
types.MethodDescriptorType,
65+
types.ClassMethodDescriptorType,
66+
]
3167

3268
logger = logging.getLogger(__name__)
3369

@@ -90,7 +126,7 @@ class methods and static methods.
90126

91127

92128
def getall(obj: Any) -> Sequence[str] | None:
93-
"""Get the ``__all__`` attribute of an object as sequence.
129+
"""Get the ``__all__`` attribute of an object as a sequence.
94130
95131
This returns ``None`` if the given ``obj.__all__`` does not exist and
96132
raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of
@@ -184,14 +220,14 @@ def isNewType(obj: Any) -> bool:
184220
return __module__ == 'typing' and __qualname__ == 'NewType.<locals>.new_type'
185221

186222

187-
def isenumclass(x: Any) -> bool:
223+
def isenumclass(x: Any) -> TypeGuard[type[Enum]]:
188224
"""Check if the object is an :class:`enumeration class <enum.Enum>`."""
189-
return isclass(x) and issubclass(x, enum.Enum)
225+
return isclass(x) and issubclass(x, Enum)
190226

191227

192-
def isenumattribute(x: Any) -> bool:
228+
def isenumattribute(x: Any) -> TypeGuard[Enum]:
193229
"""Check if the object is an enumeration attribute."""
194-
return isinstance(x, enum.Enum)
230+
return isinstance(x, Enum)
195231

196232

197233
def unpartial(obj: Any) -> Any:
@@ -206,7 +242,7 @@ def unpartial(obj: Any) -> Any:
206242
return obj
207243

208244

209-
def ispartial(obj: Any) -> bool:
245+
def ispartial(obj: Any) -> TypeGuard[partial | partialmethod]:
210246
"""Check if the object is a partial function or method."""
211247
return isinstance(obj, (partial, partialmethod))
212248

@@ -241,7 +277,7 @@ def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
241277
return False
242278

243279

244-
def isdescriptor(x: Any) -> bool:
280+
def isdescriptor(x: Any) -> TypeGuard[_SupportsGet | _SupportsSet | _SupportsDelete]:
245281
"""Check if the object is a :external+python:term:`descriptor`."""
246282
return any(
247283
callable(safe_getattr(x, item, None)) for item in ('__get__', '__set__', '__delete__')
@@ -253,7 +289,7 @@ def isabstractmethod(obj: Any) -> bool:
253289
return safe_getattr(obj, '__isabstractmethod__', False) is True
254290

255291

256-
def isboundmethod(method: MethodType) -> bool:
292+
def isboundmethod(method: types.MethodType) -> bool:
257293
"""Check if the method is a bound method."""
258294
return safe_getattr(method, '__self__', None) is not None
259295

@@ -308,12 +344,12 @@ def is_singledispatch_function(obj: Any) -> bool:
308344
)
309345

310346

311-
def is_singledispatch_method(obj: Any) -> bool:
347+
def is_singledispatch_method(obj: Any) -> TypeGuard[singledispatchmethod]:
312348
"""Check if the object is a :class:`~functools.singledispatchmethod`."""
313349
return isinstance(obj, singledispatchmethod)
314350

315351

316-
def isfunction(obj: Any) -> bool:
352+
def isfunction(obj: Any) -> TypeGuard[types.FunctionType]:
317353
"""Check if the object is a user-defined function.
318354
319355
Partial objects are unwrapped before checking them.
@@ -323,7 +359,7 @@ def isfunction(obj: Any) -> bool:
323359
return inspect.isfunction(unpartial(obj))
324360

325361

326-
def isbuiltin(obj: Any) -> bool:
362+
def isbuiltin(obj: Any) -> TypeGuard[types.BuiltinFunctionType]:
327363
"""Check if the object is a built-in function or method.
328364
329365
Partial objects are unwrapped before checking them.
@@ -333,7 +369,7 @@ def isbuiltin(obj: Any) -> bool:
333369
return inspect.isbuiltin(unpartial(obj))
334370

335371

336-
def isroutine(obj: Any) -> bool:
372+
def isroutine(obj: Any) -> TypeGuard[_RoutineType]:
337373
"""Check if the object is a kind of function or method.
338374
339375
Partial objects are unwrapped before checking them.
@@ -358,7 +394,7 @@ def _is_wrapped_coroutine(obj: Any) -> bool:
358394
return hasattr(obj, '__wrapped__')
359395

360396

361-
def isproperty(obj: Any) -> bool:
397+
def isproperty(obj: Any) -> TypeGuard[property | cached_property]:
362398
"""Check if the object is property (possibly cached)."""
363399
return isinstance(obj, (property, cached_property))
364400

@@ -433,8 +469,8 @@ def object_description(obj: Any, *, _seen: frozenset[int] = frozenset()) -> str:
433469
return 'frozenset({%s})' % ', '.join(
434470
object_description(x, _seen=seen) for x in sorted_values
435471
)
436-
elif isinstance(obj, enum.Enum):
437-
if obj.__repr__.__func__ is not enum.Enum.__repr__: # type: ignore[attr-defined]
472+
elif isinstance(obj, Enum):
473+
if obj.__repr__.__func__ is not Enum.__repr__: # type: ignore[attr-defined]
438474
return repr(obj)
439475
return f'{obj.__class__.__name__}.{obj.name}'
440476
elif isinstance(obj, tuple):
@@ -529,7 +565,7 @@ def __init__(self, modname: str, mapping: Mapping[str, str]) -> None:
529565
self.__modname = modname
530566
self.__mapping = mapping
531567

532-
self.__module: ModuleType | None = None
568+
self.__module: types.ModuleType | None = None
533569

534570
def __getattr__(self, name: str) -> Any:
535571
fullname = '.'.join(filter(None, [self.__modname, name]))

sphinx/util/typing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
if TYPE_CHECKING:
1717
from collections.abc import Mapping
18-
from typing import Final, Literal
18+
from typing import Final, Literal, Protocol
1919

2020
from typing_extensions import TypeAlias
2121

@@ -31,6 +31,10 @@
3131
'smart',
3232
]
3333

34+
class _SpecialFormInterface(Protocol):
35+
_name: str
36+
37+
3438
if sys.version_info >= (3, 10):
3539
from types import UnionType
3640
else:

0 commit comments

Comments
 (0)