Skip to content

Commit acc92ff

Browse files
picnixzAA-Turner
andauthored
Use TypeGuard in sphinx.util.inspect (#12283)
Co-authored-by: Adam Turner <[email protected]>
1 parent b6948b8 commit acc92ff

File tree

3 files changed

+69
-30
lines changed

3 files changed

+69
-30
lines changed

sphinx/ext/autodoc/mock.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from collections.abc import Iterator, Sequence
1818
from typing import Any
1919

20+
from typing_extensions import TypeGuard
21+
2022
logger = logging.getLogger(__name__)
2123

2224

@@ -154,7 +156,7 @@ def mock(modnames: list[str]) -> Iterator[None]:
154156
finder.invalidate_caches()
155157

156158

157-
def ismockmodule(subject: Any) -> bool:
159+
def ismockmodule(subject: Any) -> TypeGuard[_MockModule]:
158160
"""Check if the object is a mocked module."""
159161
return isinstance(subject, _MockModule)
160162

sphinx/util/inspect.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,36 @@
2727
from collections.abc import Callable, Sequence
2828
from inspect import _ParameterKind
2929
from types import MethodType, ModuleType
30-
from typing import Final
30+
from typing import Final, Protocol, Union
31+
32+
from typing_extensions import TypeAlias, TypeGuard
33+
34+
class _SupportsGet(Protocol):
35+
def __get__(self, __instance: Any, __owner: type | None = ...) -> Any: ... # NoQA: E704
36+
37+
class _SupportsSet(Protocol):
38+
# instance and value are contravariants but we do not need that precision
39+
def __set__(self, __instance: Any, __value: Any) -> None: ... # NoQA: E704
40+
41+
class _SupportsDelete(Protocol):
42+
# instance is contravariant but we do not need that precision
43+
def __delete__(self, __instance: Any) -> None: ... # NoQA: E704
44+
45+
_RoutineType: TypeAlias = Union[
46+
types.FunctionType,
47+
types.LambdaType,
48+
types.MethodType,
49+
types.BuiltinFunctionType,
50+
types.BuiltinMethodType,
51+
types.WrapperDescriptorType,
52+
types.MethodDescriptorType,
53+
types.ClassMethodDescriptorType,
54+
]
55+
_SignatureType: TypeAlias = Union[
56+
Callable[..., Any],
57+
staticmethod,
58+
classmethod,
59+
]
3160

3261
logger = logging.getLogger(__name__)
3362

@@ -90,7 +119,7 @@ class methods and static methods.
90119

91120

92121
def getall(obj: Any) -> Sequence[str] | None:
93-
"""Get the ``__all__`` attribute of an object as sequence.
122+
"""Get the ``__all__`` attribute of an object as a sequence.
94123
95124
This returns ``None`` if the given ``obj.__all__`` does not exist and
96125
raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of
@@ -184,12 +213,12 @@ def isNewType(obj: Any) -> bool:
184213
return __module__ == 'typing' and __qualname__ == 'NewType.<locals>.new_type'
185214

186215

187-
def isenumclass(x: Any) -> bool:
216+
def isenumclass(x: Any) -> TypeGuard[type[enum.Enum]]:
188217
"""Check if the object is an :class:`enumeration class <enum.Enum>`."""
189218
return isclass(x) and issubclass(x, enum.Enum)
190219

191220

192-
def isenumattribute(x: Any) -> bool:
221+
def isenumattribute(x: Any) -> TypeGuard[enum.Enum]:
193222
"""Check if the object is an enumeration attribute."""
194223
return isinstance(x, enum.Enum)
195224

@@ -206,12 +235,16 @@ def unpartial(obj: Any) -> Any:
206235
return obj
207236

208237

209-
def ispartial(obj: Any) -> bool:
238+
def ispartial(obj: Any) -> TypeGuard[partial | partialmethod]:
210239
"""Check if the object is a partial function or method."""
211240
return isinstance(obj, (partial, partialmethod))
212241

213242

214-
def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
243+
def isclassmethod(
244+
obj: Any,
245+
cls: Any = None,
246+
name: str | None = None,
247+
) -> TypeGuard[classmethod]:
215248
"""Check if the object is a :class:`classmethod`."""
216249
if isinstance(obj, classmethod):
217250
return True
@@ -227,7 +260,11 @@ def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
227260
return False
228261

229262

230-
def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
263+
def isstaticmethod(
264+
obj: Any,
265+
cls: Any = None,
266+
name: str | None = None,
267+
) -> TypeGuard[staticmethod]:
231268
"""Check if the object is a :class:`staticmethod`."""
232269
if isinstance(obj, staticmethod):
233270
return True
@@ -241,7 +278,7 @@ def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
241278
return False
242279

243280

244-
def isdescriptor(x: Any) -> bool:
281+
def isdescriptor(x: Any) -> TypeGuard[_SupportsGet | _SupportsSet | _SupportsDelete]:
245282
"""Check if the object is a :external+python:term:`descriptor`."""
246283
return any(
247284
callable(safe_getattr(x, item, None)) for item in ('__get__', '__set__', '__delete__')
@@ -308,12 +345,12 @@ def is_singledispatch_function(obj: Any) -> bool:
308345
)
309346

310347

311-
def is_singledispatch_method(obj: Any) -> bool:
348+
def is_singledispatch_method(obj: Any) -> TypeGuard[singledispatchmethod]:
312349
"""Check if the object is a :class:`~functools.singledispatchmethod`."""
313350
return isinstance(obj, singledispatchmethod)
314351

315352

316-
def isfunction(obj: Any) -> bool:
353+
def isfunction(obj: Any) -> TypeGuard[types.FunctionType]:
317354
"""Check if the object is a user-defined function.
318355
319356
Partial objects are unwrapped before checking them.
@@ -323,7 +360,7 @@ def isfunction(obj: Any) -> bool:
323360
return inspect.isfunction(unpartial(obj))
324361

325362

326-
def isbuiltin(obj: Any) -> bool:
363+
def isbuiltin(obj: Any) -> TypeGuard[types.BuiltinFunctionType]:
327364
"""Check if the object is a built-in function or method.
328365
329366
Partial objects are unwrapped before checking them.
@@ -333,7 +370,7 @@ def isbuiltin(obj: Any) -> bool:
333370
return inspect.isbuiltin(unpartial(obj))
334371

335372

336-
def isroutine(obj: Any) -> bool:
373+
def isroutine(obj: Any) -> TypeGuard[_RoutineType]:
337374
"""Check if the object is a kind of function or method.
338375
339376
Partial objects are unwrapped before checking them.
@@ -343,7 +380,7 @@ def isroutine(obj: Any) -> bool:
343380
return inspect.isroutine(unpartial(obj))
344381

345382

346-
def iscoroutinefunction(obj: Any) -> bool:
383+
def iscoroutinefunction(obj: Any) -> TypeGuard[Callable[..., types.CoroutineType]]:
347384
"""Check if the object is a :external+python:term:`coroutine` function."""
348385
obj = unwrap_all(obj, stop=_is_wrapped_coroutine)
349386
return inspect.iscoroutinefunction(obj)
@@ -358,12 +395,12 @@ def _is_wrapped_coroutine(obj: Any) -> bool:
358395
return hasattr(obj, '__wrapped__')
359396

360397

361-
def isproperty(obj: Any) -> bool:
398+
def isproperty(obj: Any) -> TypeGuard[property | cached_property]:
362399
"""Check if the object is property (possibly cached)."""
363400
return isinstance(obj, (property, cached_property))
364401

365402

366-
def isgenericalias(obj: Any) -> bool:
403+
def isgenericalias(obj: Any) -> TypeGuard[types.GenericAlias]:
367404
"""Check if the object is a generic alias."""
368405
return isinstance(obj, (types.GenericAlias, typing._BaseGenericAlias)) # type: ignore[attr-defined]
369406

@@ -579,7 +616,7 @@ def __getitem__(self, key: str) -> Any:
579616
raise KeyError
580617

581618

582-
def _should_unwrap(subject: Callable[..., Any]) -> bool:
619+
def _should_unwrap(subject: _SignatureType) -> bool:
583620
"""Check the function should be unwrapped on getting signature."""
584621
__globals__ = getglobals(subject)
585622
# contextmanger should be unwrapped
@@ -590,7 +627,7 @@ def _should_unwrap(subject: Callable[..., Any]) -> bool:
590627

591628

592629
def signature(
593-
subject: Callable[..., Any],
630+
subject: _SignatureType,
594631
bound_method: bool = False,
595632
type_aliases: Mapping[str, str] | None = None,
596633
) -> Signature:
@@ -603,12 +640,12 @@ def signature(
603640

604641
try:
605642
if _should_unwrap(subject):
606-
signature = inspect.signature(subject)
643+
signature = inspect.signature(subject) # type: ignore[arg-type]
607644
else:
608-
signature = inspect.signature(subject, follow_wrapped=True)
645+
signature = inspect.signature(subject, follow_wrapped=True) # type: ignore[arg-type]
609646
except ValueError:
610647
# follow built-in wrappers up (ex. functools.lru_cache)
611-
signature = inspect.signature(subject)
648+
signature = inspect.signature(subject) # type: ignore[arg-type]
612649
parameters = list(signature.parameters.values())
613650
return_annotation = signature.return_annotation
614651

sphinx/util/typing.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def is_system_TypeVar(typ: Any) -> bool:
164164
return modname == 'typing' and isinstance(typ, TypeVar)
165165

166166

167-
def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typing') -> str:
167+
def restify(cls: Any, mode: _RestifyMode = 'fully-qualified-except-typing') -> str:
168168
"""Convert python class to a reST reference.
169169
170170
:param mode: Specify a method how annotations will be stringified.
@@ -229,25 +229,25 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin
229229
return f':py:class:`{cls.__name__}`'
230230
elif (inspect.isgenericalias(cls)
231231
and cls.__module__ == 'typing'
232-
and cls.__origin__ is Union): # type: ignore[attr-defined]
232+
and cls.__origin__ is Union):
233233
# *cls* is defined in ``typing``, and thus ``__args__`` must exist
234-
return ' | '.join(restify(a, mode) for a in cls.__args__) # type: ignore[attr-defined]
234+
return ' | '.join(restify(a, mode) for a in cls.__args__)
235235
elif inspect.isgenericalias(cls):
236-
if isinstance(cls.__origin__, typing._SpecialForm): # type: ignore[attr-defined]
237-
text = restify(cls.__origin__, mode) # type: ignore[attr-defined,arg-type]
236+
if isinstance(cls.__origin__, typing._SpecialForm):
237+
text = restify(cls.__origin__, mode)
238238
elif getattr(cls, '_name', None):
239-
cls_name = cls._name # type: ignore[attr-defined]
239+
cls_name = cls._name
240240
text = f':py:class:`{modprefix}{cls.__module__}.{cls_name}`'
241241
else:
242-
text = restify(cls.__origin__, mode) # type: ignore[attr-defined]
242+
text = restify(cls.__origin__, mode)
243243

244244
origin = getattr(cls, '__origin__', None)
245245
if not hasattr(cls, '__args__'): # NoQA: SIM114
246246
pass
247247
elif all(is_system_TypeVar(a) for a in cls.__args__):
248248
# Suppress arguments if all system defined TypeVars (ex. Dict[KT, VT])
249249
pass
250-
elif cls.__module__ == 'typing' and cls._name == 'Callable': # type: ignore[attr-defined]
250+
elif cls.__module__ == 'typing' and cls._name == 'Callable':
251251
args = ', '.join(restify(a, mode) for a in cls.__args__[:-1])
252252
text += fr'\ [[{args}], {restify(cls.__args__[-1], mode)}]'
253253
elif cls.__module__ == 'typing' and getattr(origin, '_name', None) == 'Literal':
@@ -259,7 +259,7 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin
259259

260260
return text
261261
elif isinstance(cls, typing._SpecialForm):
262-
return f':py:obj:`~{cls.__module__}.{cls._name}`'
262+
return f':py:obj:`~{cls.__module__}.{cls._name}`' # type: ignore[attr-defined]
263263
elif sys.version_info[:2] >= (3, 11) and cls is typing.Any:
264264
# handle bpo-46998
265265
return f':py:obj:`~{cls.__module__}.{cls.__name__}`'

0 commit comments

Comments
 (0)