Skip to content

Commit 60f64f3

Browse files
committed
add type guards
1 parent 4390b7f commit 60f64f3

File tree

5 files changed

+112
-43
lines changed

5 files changed

+112
-43
lines changed

.ruff.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,6 @@ exclude = [
545545
"sphinx/util/cfamily.py",
546546
"sphinx/util/math.py",
547547
"sphinx/util/logging.py",
548-
"sphinx/util/inspect.py",
549548
"sphinx/util/parallel.py",
550549
"sphinx/util/inventory.py",
551550
"sphinx/util/__init__.py",

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ lint = [
8787
"sphinx-lint",
8888
"types-docutils",
8989
"types-requests",
90+
"typing-extensions",
9091
"importlib_metadata", # for mypy (Python<=3.9)
9192
"tomli", # for mypy (Python<=3.10)
9293
"pytest>=6.0",

sphinx/ext/autodoc/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2267,7 +2267,7 @@ def format_signature(self, **kwargs: Any) -> str:
22672267
pass # default implementation. skipped.
22682268
else:
22692269
if inspect.isclassmethod(func):
2270-
func = func.__func__
2270+
func = func.__func__ # type: ignore[attr-defined]
22712271
dispatchmeth = self.annotate_to_first_argument(func, typ)
22722272
if dispatchmeth:
22732273
documenter = MethodDocumenter(self.directive, '')

sphinx/util/inspect.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import ast
66
import builtins
77
import contextlib
8-
import enum
98
import inspect
109
import re
1110
import sys
1211
import types
1312
import typing
1413
from collections.abc import Mapping
14+
from enum import Enum
1515
from functools import cached_property, partial, partialmethod, singledispatchmethod
1616
from importlib import import_module
1717
from inspect import Parameter, Signature
@@ -26,8 +26,44 @@
2626
if TYPE_CHECKING:
2727
from collections.abc import Callable, Sequence
2828
from inspect import _ParameterKind
29-
from types import MethodType, ModuleType
30-
from typing import Final
29+
from typing import Final, Protocol, Union
30+
31+
from typing_extensions import TypeGuard
32+
33+
class _SupportsGet(Protocol):
34+
def __get__(self, __instance: Any, __owner: type | None = ...) -> Any: ... # NoQA: E704
35+
36+
class _SupportsSet(Protocol):
37+
# instance and value are contravariants but we do not need that precision
38+
def __set__(self, __instance: Any, __value: Any) -> None: ... # NoQA: E704
39+
40+
class _SupportsDelete(Protocol):
41+
# instance is contravariant but we do not need that precision
42+
def __delete__(self, __instance: Any) -> None: ... # NoQA: E704
43+
44+
class _GenericAliasLike(Protocol):
45+
# Minimalist interface for a generic alias in typing.py.
46+
47+
# The ``__origin__`` is defined both on the base type for generics
48+
# in typing.py *and* on the public ``types.GenericAlias`` type.
49+
__origin__: type
50+
51+
# Note that special generic alias types (tuple and callables) do
52+
# not directly define ``__args__``. At runtime, however, they are
53+
# actually instances of ``typing._GenericAlias`` which does have
54+
# an ``__args__`` field.
55+
__args__: tuple[Any, ...]
56+
57+
_RoutineType = Union[
58+
types.FunctionType,
59+
types.LambdaType,
60+
types.MethodType,
61+
types.BuiltinFunctionType,
62+
types.BuiltinMethodType,
63+
types.WrapperDescriptorType,
64+
types.MethodDescriptorType,
65+
types.ClassMethodDescriptorType,
66+
]
3167

3268
logger = logging.getLogger(__name__)
3369

@@ -90,7 +126,7 @@ class methods and static methods.
90126

91127

92128
def getall(obj: Any) -> Sequence[str] | None:
93-
"""Get the ``__all__`` attribute of an object as sequence.
129+
"""Get the ``__all__`` attribute of an object as a sequence.
94130
95131
This returns ``None`` if the given ``obj.__all__`` does not exist and
96132
raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of
@@ -184,14 +220,14 @@ def isNewType(obj: Any) -> bool:
184220
return __module__ == 'typing' and __qualname__ == 'NewType.<locals>.new_type'
185221

186222

187-
def isenumclass(x: Any) -> bool:
223+
def isenumclass(x: Any) -> TypeGuard[type[Enum]]:
188224
"""Check if the object is an :class:`enumeration class <enum.Enum>`."""
189-
return isclass(x) and issubclass(x, enum.Enum)
225+
return isclass(x) and issubclass(x, Enum)
190226

191227

192-
def isenumattribute(x: Any) -> bool:
228+
def isenumattribute(x: Any) -> TypeGuard[Enum]:
193229
"""Check if the object is an enumeration attribute."""
194-
return isinstance(x, enum.Enum)
230+
return isinstance(x, Enum)
195231

196232

197233
def unpartial(obj: Any) -> Any:
@@ -206,7 +242,7 @@ def unpartial(obj: Any) -> Any:
206242
return obj
207243

208244

209-
def ispartial(obj: Any) -> bool:
245+
def ispartial(obj: Any) -> TypeGuard[partial | partialmethod]:
210246
"""Check if the object is a partial function or method."""
211247
return isinstance(obj, (partial, partialmethod))
212248

@@ -239,7 +275,7 @@ def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
239275
return False
240276

241277

242-
def isdescriptor(x: Any) -> bool:
278+
def isdescriptor(x: Any) -> TypeGuard[_SupportsGet | _SupportsSet | _SupportsDelete]:
243279
"""Check if the object is a :external+python:term:`descriptor`."""
244280
return any(
245281
callable(safe_getattr(x, item, None)) for item in ('__get__', '__set__', '__delete__')
@@ -251,7 +287,7 @@ def isabstractmethod(obj: Any) -> bool:
251287
return safe_getattr(obj, '__isabstractmethod__', False) is True
252288

253289

254-
def isboundmethod(method: MethodType) -> bool:
290+
def isboundmethod(method: types.MethodType) -> bool:
255291
"""Check if the method is a bound method."""
256292
return safe_getattr(method, '__self__', None) is not None
257293

@@ -306,12 +342,12 @@ def is_singledispatch_function(obj: Any) -> bool:
306342
)
307343

308344

309-
def is_singledispatch_method(obj: Any) -> bool:
345+
def is_singledispatch_method(obj: Any) -> TypeGuard[singledispatchmethod]:
310346
"""Check if the object is a :class:`~functools.singledispatchmethod`."""
311347
return isinstance(obj, singledispatchmethod)
312348

313349

314-
def isfunction(obj: Any) -> bool:
350+
def isfunction(obj: Any) -> TypeGuard[types.FunctionType]:
315351
"""Check if the object is a user-defined function.
316352
317353
Partial objects are unwrapped before checking them.
@@ -321,7 +357,7 @@ def isfunction(obj: Any) -> bool:
321357
return inspect.isfunction(unpartial(obj))
322358

323359

324-
def isbuiltin(obj: Any) -> bool:
360+
def isbuiltin(obj: Any) -> TypeGuard[types.BuiltinFunctionType]:
325361
"""Check if the object is a built-in function or method.
326362
327363
Partial objects are unwrapped before checking them.
@@ -331,7 +367,7 @@ def isbuiltin(obj: Any) -> bool:
331367
return inspect.isbuiltin(unpartial(obj))
332368

333369

334-
def isroutine(obj: Any) -> bool:
370+
def isroutine(obj: Any) -> TypeGuard[_RoutineType]:
335371
"""Check if the object is a kind of function or method.
336372
337373
Partial objects are unwrapped before checking them.
@@ -356,7 +392,7 @@ def _is_wrapped_coroutine(obj: Any) -> bool:
356392
return hasattr(obj, '__wrapped__')
357393

358394

359-
def isproperty(obj: Any) -> bool:
395+
def isproperty(obj: Any) -> TypeGuard[property | cached_property]:
360396
"""Check if the object is property (possibly cached)."""
361397
return isinstance(obj, (property, cached_property))
362398

@@ -431,8 +467,8 @@ def object_description(obj: Any, *, _seen: frozenset[int] = frozenset()) -> str:
431467
return 'frozenset({%s})' % ', '.join(
432468
object_description(x, _seen=seen) for x in sorted_values
433469
)
434-
elif isinstance(obj, enum.Enum):
435-
if obj.__repr__.__func__ is not enum.Enum.__repr__: # type: ignore[attr-defined]
470+
elif isinstance(obj, Enum):
471+
if obj.__repr__.__func__ is not Enum.__repr__: # type: ignore[attr-defined]
436472
return repr(obj)
437473
return f'{obj.__class__.__name__}.{obj.name}'
438474
elif isinstance(obj, tuple):
@@ -527,7 +563,7 @@ def __init__(self, modname: str, mapping: Mapping[str, str]) -> None:
527563
self.__modname = modname
528564
self.__mapping = mapping
529565

530-
self.__module: ModuleType | None = None
566+
self.__module: types.ModuleType | None = None
531567

532568
def __getattr__(self, name: str) -> Any:
533569
fullname = '.'.join(filter(None, [self.__modname, name]))

sphinx/util/typing.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,32 @@
88
from collections.abc import Sequence
99
from contextvars import Context, ContextVar, Token
1010
from struct import Struct
11-
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypedDict, TypeVar, Union
11+
from typing import (
12+
TYPE_CHECKING,
13+
Annotated,
14+
Any,
15+
Callable,
16+
ForwardRef,
17+
TypedDict,
18+
TypeVar,
19+
Union,
20+
)
1221

1322
from docutils import nodes
1423
from docutils.parsers.rst.states import Inliner
1524

1625
if TYPE_CHECKING:
1726
from collections.abc import Mapping
18-
from typing import Final
27+
from typing import Final, Protocol
28+
29+
from typing_extensions import TypeGuard
1930

2031
from sphinx.application import Sphinx
2132

33+
class _SpecialFormInterface(Protocol):
34+
_name: str
35+
36+
2237
if sys.version_info >= (3, 10):
2338
from types import UnionType
2439
else:
@@ -152,8 +167,27 @@ def is_system_TypeVar(typ: Any) -> bool:
152167
return modname == 'typing' and isinstance(typ, TypeVar)
153168

154169

155-
def restify(cls: type | None, mode: str = 'fully-qualified-except-typing') -> str:
156-
"""Convert python class to a reST reference.
170+
def _is_special_form(obj: Any) -> TypeGuard[_SpecialFormInterface]:
171+
"""Check if *obj* is a typing special form.
172+
173+
The guarded type is a protocol with the members that Sphinx needs in
174+
this module and not the native ``typing._SpecialForm`` from typeshed,
175+
but the runtime type of *obj* must be a true special form instance.
176+
"""
177+
return isinstance(obj, typing._SpecialForm)
178+
179+
180+
def _is_annotated_form(obj: Any) -> TypeGuard[Annotated[Any, ...]]:
181+
"""Check if *obj* is an annotated type."""
182+
return typing.get_origin(obj) is Annotated or str(obj).startswith('typing.Annotated')
183+
184+
185+
def _get_typing_internal_name(obj: Any) -> str | None:
186+
return getattr(obj, '_name', None)
187+
188+
189+
def restify(cls: Any, mode: str = 'fully-qualified-except-typing') -> str:
190+
"""Convert a python type-like object to a reST reference.
157191
158192
:param mode: Specify a method how annotations will be stringified.
159193
@@ -205,10 +239,10 @@ def restify(cls: type | None, mode: str = 'fully-qualified-except-typing') -> st
205239
elif (
206240
inspect.isgenericalias(cls)
207241
and cls.__module__ == 'typing'
208-
and cls.__origin__ is Union # type: ignore[attr-defined]
242+
and cls.__origin__ is Union
209243
):
210-
# *cls* is defined in ``typing``, and thus ``__args__`` must exist;
211-
if NoneType in (__args__ := cls.__args__): # type: ignore[attr-defined]
244+
# *cls* is defined in ``typing``, thus ``__args__`` should exist
245+
if NoneType in (__args__ := cls.__args__):
212246
# Shape: Union[T_1, ..., T_k, None, T_{k+1}, ..., T_n]
213247
#
214248
# Note that we keep Literal[None] in their rightful place
@@ -226,33 +260,33 @@ def restify(cls: type | None, mode: str = 'fully-qualified-except-typing') -> st
226260
args = ', '.join(restify(a, mode) for a in __args__)
227261
return rf':py:obj:`~typing.Union`\ [{args}]'
228262
elif inspect.isgenericalias(cls):
229-
if isinstance(cls.__origin__, typing._SpecialForm): # type: ignore[attr-defined]
230-
text = restify(cls.__origin__, mode) # type: ignore[attr-defined,arg-type]
231-
elif getattr(cls, '_name', None):
232-
cls_name = cls._name # type: ignore[attr-defined]
263+
__origin__ = cls.__origin__
264+
if _is_annotated_form(__origin__):
265+
text = restify(__origin__, mode)
266+
elif internal_name := _get_typing_internal_name(cls):
233267
if cls.__module__ == 'typing':
234-
text = f':py:class:`~{cls.__module__}.{cls_name}`'
268+
text = f':py:class:`~{cls.__module__}.{internal_name}`'
235269
else:
236-
text = f':py:class:`{modprefix}{cls.__module__}.{cls_name}`'
270+
text = f':py:class:`{modprefix}{cls.__module__}.{internal_name}`'
237271
else:
238-
text = restify(cls.__origin__, mode) # type: ignore[attr-defined]
272+
text = restify(__origin__, mode)
239273

240274
origin = getattr(cls, '__origin__', None)
241275
if not hasattr(cls, '__args__'): # NoQA: SIM114
242276
pass
243277
elif all(is_system_TypeVar(a) for a in cls.__args__):
244278
# Suppress arguments if all system defined TypeVars (ex. Dict[KT, VT])
245279
pass
246-
elif cls.__module__ == 'typing' and cls._name == 'Callable': # type: ignore[attr-defined]
280+
elif cls.__module__ == 'typing' and _get_typing_internal_name(cls) == 'Callable':
247281
args = ', '.join(restify(a, mode) for a in cls.__args__[:-1])
248282
text += rf'\ [[{args}], {restify(cls.__args__[-1], mode)}]'
249-
elif cls.__module__ == 'typing' and getattr(origin, '_name', None) == 'Literal':
283+
elif cls.__module__ == 'typing' and _get_typing_internal_name(origin) == 'Literal':
250284
literals = ', '.join(_restify_literal_arg(a, mode) for a in cls.__args__)
251285
text += rf'\ [{literals}]'
252286
elif cls.__args__:
253287
text += rf'\ [{", ".join(restify(a, mode) for a in cls.__args__)}]'
254288
return text
255-
elif isinstance(cls, typing._SpecialForm):
289+
elif _is_special_form(cls):
256290
return f':py:obj:`~{cls.__module__}.{cls._name}`'
257291
elif sys.version_info[:2] >= (3, 11) and cls is typing.Any:
258292
# handle bpo-46998
@@ -332,7 +366,7 @@ def stringify_annotation(
332366
return f'{module_prefix}{annotation_module}.{annotation_name}'
333367
elif is_invalid_builtin_class(annotation):
334368
return f'{module_prefix}{_INVALID_BUILTIN_CLASSES[annotation]}'
335-
elif str(annotation).startswith('typing.Annotated'): # for py39+
369+
elif _is_annotated_form(annotation): # for py39+
336370
pass
337371
elif annotation_module == 'builtins' and annotation_qualname:
338372
if (args := getattr(annotation, '__args__', None)) is None:
@@ -361,9 +395,8 @@ def stringify_annotation(
361395
# handle ForwardRefs
362396
qualname = annotation_forward_arg
363397
else:
364-
_name = getattr(annotation, '_name', '')
365-
if _name:
366-
qualname = _name
398+
if internal_name := _get_typing_internal_name(annotation):
399+
qualname = internal_name
367400
elif annotation_qualname:
368401
qualname = annotation_qualname
369402
else:
@@ -397,7 +430,7 @@ def stringify_annotation(
397430
elif qualname == 'Literal':
398431
literals = ', '.join(_stringify_literal_arg(a, mode) for a in annotation_args)
399432
return f'{module_prefix}Literal[{literals}]'
400-
elif str(annotation).startswith('typing.Annotated'): # for py39+
433+
elif _is_annotated_form(annotation): # for py39+
401434
return stringify_annotation(annotation_args[0], mode)
402435
elif all(is_system_TypeVar(a) for a in annotation_args):
403436
# Suppress arguments if all system defined TypeVars (ex. Dict[KT, VT])

0 commit comments

Comments
 (0)