Skip to content

Commit 2fbf1c4

Browse files
authored
Add type information and check via mypy (#195)
1 parent 6dc9239 commit 2fbf1c4

File tree

10 files changed

+147
-110
lines changed

10 files changed

+147
-110
lines changed

.github/workflows/check.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ jobs:
105105
tox_env:
106106
- dev
107107
- readme
108+
- type
108109
steps:
109110
- uses: actions/checkout@v2
110111
with:

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,13 @@ write_to = "src/sphinx_autodoc_typehints/version.py"
1515

1616
[tool.pytest.ini_options]
1717
testpaths = ["tests"]
18+
19+
[tool.mypy]
20+
python_version = "3.10"
21+
strict = true
22+
exclude = "^.*/roots/.*$"
23+
24+
25+
[[tool.mypy.overrides]]
26+
module = ["sphobjinv"]
27+
ignore_missing_imports = true

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ testing =
5050
type_comments =
5151
typed-ast>=1.4.0;python_version < "3.8"
5252

53+
[options.package_data]
54+
sphinx_autodoc_typehints = py.typed
55+
5356
[coverage:run]
5457
plugins = covdefaults
5558
parallel = true

src/sphinx_autodoc_typehints/__init__.py

Lines changed: 61 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
from __future__ import annotations
2+
13
import inspect
24
import sys
35
import textwrap
46
import typing
5-
from typing import Any, AnyStr, Dict, NewType, Optional, Tuple, TypeVar, get_type_hints
7+
from ast import FunctionDef, Module, stmt
8+
from typing import Any, AnyStr, NewType, TypeVar, get_type_hints
69

710
from sphinx.application import Sphinx
11+
from sphinx.environment import BuildEnvironment
12+
from sphinx.ext.autodoc import Options
813
from sphinx.util import logging
914
from sphinx.util.inspect import signature as sphinx_signature
1015
from sphinx.util.inspect import stringify_signature
@@ -19,24 +24,24 @@
1924
]
2025

2126

22-
def get_annotation_module(annotation) -> str:
27+
def get_annotation_module(annotation: Any) -> str:
2328
# Special cases
2429
if annotation is None:
2530
return "builtins"
2631

27-
if sys.version_info >= (3, 10) and isinstance(annotation, NewType):
32+
if sys.version_info >= (3, 10) and isinstance(annotation, NewType): # type: ignore # isinstance NewType is Callable
2833
return "typing"
2934

3035
if hasattr(annotation, "__module__"):
31-
return annotation.__module__
36+
return annotation.__module__ # type: ignore # deduced Any
3237

3338
if hasattr(annotation, "__origin__"):
34-
return annotation.__origin__.__module__
39+
return annotation.__origin__.__module__ # type: ignore # deduced Any
3540

3641
raise ValueError(f"Cannot determine the module of {annotation}")
3742

3843

39-
def get_annotation_class_name(annotation, module: str) -> str:
44+
def get_annotation_class_name(annotation: Any, module: str) -> str:
4045
# Special cases
4146
if annotation is None:
4247
return "None"
@@ -45,32 +50,30 @@ def get_annotation_class_name(annotation, module: str) -> str:
4550
elif annotation is AnyStr:
4651
return "AnyStr"
4752
elif (sys.version_info < (3, 10) and inspect.isfunction(annotation) and hasattr(annotation, "__supertype__")) or (
48-
sys.version_info >= (3, 10) and isinstance(annotation, NewType)
53+
sys.version_info >= (3, 10) and isinstance(annotation, NewType) # type: ignore # isinstance NewType is Callable
4954
):
5055
return "NewType"
5156

5257
if getattr(annotation, "__qualname__", None):
53-
return annotation.__qualname__
58+
return annotation.__qualname__ # type: ignore # deduced Any
5459
elif getattr(annotation, "_name", None): # Required for generic aliases on Python 3.7+
55-
return annotation._name
60+
return annotation._name # type: ignore # deduced Any
5661
elif module in ("typing", "typing_extensions") and isinstance(getattr(annotation, "name", None), str):
5762
# Required for at least Pattern and Match
58-
return annotation.name
63+
return annotation.name # type: ignore # deduced Any
5964

6065
origin = getattr(annotation, "__origin__", None)
6166
if origin:
6267
if getattr(origin, "__qualname__", None): # Required for Protocol subclasses
63-
return origin.__qualname__
68+
return origin.__qualname__ # type: ignore # deduced Any
6469
elif getattr(origin, "_name", None): # Required for Union on Python 3.7+
65-
return origin._name
66-
else:
67-
return origin.__class__.__qualname__.lstrip("_") # Required for Union on Python < 3.7
70+
return origin._name # type: ignore # deduced Any
6871

6972
annotation_cls = annotation if inspect.isclass(annotation) else annotation.__class__
70-
return annotation_cls.__qualname__.lstrip("_")
73+
return annotation_cls.__qualname__.lstrip("_") # type: ignore # deduced Any
7174

7275

73-
def get_annotation_args(annotation, module: str, class_name: str) -> Tuple:
76+
def get_annotation_args(annotation: Any, module: str, class_name: str) -> tuple[Any, ...]:
7477
try:
7578
original = getattr(sys.modules[module], class_name)
7679
except (KeyError, AttributeError):
@@ -87,14 +90,14 @@ def get_annotation_args(annotation, module: str, class_name: str) -> Tuple:
8790
elif class_name == "NewType" and hasattr(annotation, "__supertype__"):
8891
return (annotation.__supertype__,)
8992
elif class_name == "Literal" and hasattr(annotation, "__values__"):
90-
return annotation.__values__
93+
return annotation.__values__ # type: ignore # deduced Any
9194
elif class_name == "Generic":
92-
return annotation.__parameters__
95+
return annotation.__parameters__ # type: ignore # deduced Any
9396

9497
return getattr(annotation, "__args__", ())
9598

9699

97-
def format_annotation(annotation, fully_qualified: bool = False, simplify_optional_unions: bool = True) -> str:
100+
def format_annotation(annotation: Any, fully_qualified: bool = False, simplify_optional_unions: bool = True) -> str:
98101
# Special cases
99102
if annotation is None or annotation is type(None): # noqa: E721
100103
return ":py:obj:`None`"
@@ -154,25 +157,25 @@ def format_annotation(annotation, fully_qualified: bool = False, simplify_option
154157

155158

156159
# reference: https://github.com/pytorch/pytorch/pull/46548/files
157-
def normalize_source_lines(sourcelines: str) -> str:
160+
def normalize_source_lines(source_lines: str) -> str:
158161
"""
159162
This helper function accepts a list of source lines. It finds the
160163
indentation level of the function definition (`def`), then it indents
161164
all lines in the function body to a point at or greater than that
162165
level. This allows for comments and continued string literals that
163166
are at a lower indentation than the rest of the code.
164167
Arguments:
165-
sourcelines: source code
168+
source_lines: source code
166169
Returns:
167170
source lines that have been correctly aligned
168171
"""
169-
sourcelines = sourcelines.split("\n")
172+
lines = source_lines.split("\n")
170173

171-
def remove_prefix(text, prefix):
174+
def remove_prefix(text: str, prefix: str) -> str:
172175
return text[text.startswith(prefix) and len(prefix) :]
173176

174177
# Find the line and line number containing the function definition
175-
for i, l in enumerate(sourcelines):
178+
for i, l in enumerate(lines):
176179
if l.lstrip().startswith("def"):
177180
idx = i
178181
whitespace_separator = "def"
@@ -183,35 +186,37 @@ def remove_prefix(text, prefix):
183186
break
184187

185188
else:
186-
return "\n".join(sourcelines)
187-
fn_def = sourcelines[idx]
189+
return "\n".join(lines)
190+
fn_def = lines[idx]
188191

189192
# Get a string representing the amount of leading whitespace
190193
whitespace = fn_def.split(whitespace_separator)[0]
191194

192195
# Add this leading whitespace to all lines before and after the `def`
193-
aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
194-
aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]]
196+
aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in lines[:idx]]
197+
aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in lines[idx + 1 :]]
195198

196199
# Put it together again
197200
aligned_prefix.append(fn_def)
198201
return "\n".join(aligned_prefix + aligned_suffix)
199202

200203

201-
def process_signature(app, what: str, name: str, obj, options, signature, return_annotation): # noqa: U100
204+
def process_signature(
205+
app: Sphinx, what: str, name: str, obj: Any, options: Options, signature: str, return_annotation: str # noqa: U100
206+
) -> tuple[str, None] | None:
202207
if not callable(obj):
203-
return
208+
return None
204209

205210
original_obj = obj
206211
if inspect.isclass(obj):
207212
obj = getattr(obj, "__init__", getattr(obj, "__new__", None))
208213

209214
if not getattr(obj, "__annotations__", None):
210-
return
215+
return None
211216

212217
obj = inspect.unwrap(obj)
213-
signature = sphinx_signature(obj)
214-
parameters = [param.replace(annotation=inspect.Parameter.empty) for param in signature.parameters.values()]
218+
sph_signature = sphinx_signature(obj)
219+
parameters = [param.replace(annotation=inspect.Parameter.empty) for param in sph_signature.parameters.values()]
215220

216221
# The generated dataclass __init__() and class are weird and need extra checks
217222
# This helper function operates on the generated class and methods
@@ -228,15 +233,15 @@ def _is_dataclass(name: str, what: str, qualname: str) -> bool:
228233

229234
if "<locals>" in obj.__qualname__ and not _is_dataclass(name, what, obj.__qualname__):
230235
logger.warning('Cannot treat a function defined as a local function: "%s" (use @functools.wraps)', name)
231-
return
236+
return None
232237

233238
if parameters:
234239
if inspect.isclass(original_obj) or (what == "method" and name.endswith(".__init__")):
235240
del parameters[0]
236241
elif what == "method":
237242
outer = inspect.getmodule(obj)
238-
for clsname in obj.__qualname__.split(".")[:-1]:
239-
outer = getattr(outer, clsname)
243+
for class_name in obj.__qualname__.split(".")[:-1]:
244+
outer = getattr(outer, class_name)
240245

241246
method_name = obj.__name__
242247
if method_name.startswith("__") and not method_name.endswith("__"):
@@ -250,27 +255,23 @@ def _is_dataclass(name: str, what: str, qualname: str) -> bool:
250255
if not isinstance(method_object, (classmethod, staticmethod)):
251256
del parameters[0]
252257

253-
signature = signature.replace(parameters=parameters, return_annotation=inspect.Signature.empty)
254-
255-
return stringify_signature(signature).replace("\\", "\\\\"), None
258+
sph_signature = sph_signature.replace(parameters=parameters, return_annotation=inspect.Signature.empty)
256259

260+
return stringify_signature(sph_signature).replace("\\", "\\\\"), None
257261

258-
def _future_annotations_imported(obj):
259-
if sys.version_info < (3, 7):
260-
# Only Python ≥ 3.7 supports PEP563.
261-
return False
262262

263+
def _future_annotations_imported(obj: Any) -> bool:
263264
_annotations = getattr(inspect.getmodule(obj), "annotations", None)
264265
if _annotations is None:
265266
return False
266267

267268
# Make sure that annotations is imported from __future__ - defined in cpython/Lib/__future__.py
268269
# annotations become strings at runtime
269270
future_annotations = 0x100000 if sys.version_info[0:2] == (3, 7) else 0x1000000
270-
return _annotations.compiler_flag == future_annotations
271+
return bool(_annotations.compiler_flag == future_annotations)
271272

272273

273-
def get_all_type_hints(obj, name):
274+
def get_all_type_hints(obj: Any, name: str) -> dict[str, Any]:
274275
rv = {}
275276

276277
try:
@@ -310,7 +311,7 @@ def get_all_type_hints(obj, name):
310311
return rv
311312

312313

313-
def backfill_type_hints(obj, name):
314+
def backfill_type_hints(obj: Any, name: str) -> dict[str, Any]:
314315
parse_kwargs = {}
315316
if sys.version_info < (3, 8):
316317
try:
@@ -322,17 +323,18 @@ def backfill_type_hints(obj, name):
322323

323324
parse_kwargs = {"type_comments": True}
324325

325-
def _one_child(module):
326+
def _one_child(module: Module) -> stmt | None:
326327
children = module.body # use the body to ignore type comments
327328

328329
if len(children) != 1:
329330
logger.warning('Did not get exactly one node from AST for "%s", got %s', name, len(children))
330-
return
331+
return None
331332

332333
return children[0]
333334

334335
try:
335-
obj_ast = ast.parse(textwrap.dedent(normalize_source_lines(inspect.getsource(obj))), **parse_kwargs)
336+
code = textwrap.dedent(normalize_source_lines(inspect.getsource(obj)))
337+
obj_ast = ast.parse(code, **parse_kwargs) # type: ignore # dynamic kwargs
336338
except (OSError, TypeError):
337339
return {}
338340

@@ -385,7 +387,7 @@ def _one_child(module):
385387
return rv
386388

387389

388-
def load_args(obj_ast):
390+
def load_args(obj_ast: FunctionDef) -> list[Any]:
389391
func_args = obj_ast.args
390392
args = []
391393
pos_only = getattr(func_args, "posonlyargs", None)
@@ -403,12 +405,12 @@ def load_args(obj_ast):
403405
return args
404406

405407

406-
def split_type_comment_args(comment):
407-
def add(val):
408+
def split_type_comment_args(comment: str) -> list[str | None]:
409+
def add(val: str) -> None:
408410
result.append(val.strip().lstrip("*")) # remove spaces, and var/kw arg marker
409411

410412
comment = comment.strip().lstrip("(").rstrip(")")
411-
result = []
413+
result: list[str | None] = []
412414
if not comment:
413415
return result
414416

@@ -426,7 +428,7 @@ def add(val):
426428
return result
427429

428430

429-
def format_default(app: Sphinx, default: Any) -> Optional[str]:
431+
def format_default(app: Sphinx, default: Any) -> str | None:
430432
if default is inspect.Parameter.empty:
431433
return None
432434
formatted = repr(default).replace("\\", "\\\\")
@@ -436,7 +438,9 @@ def format_default(app: Sphinx, default: Any) -> Optional[str]:
436438
return f", default: ``{formatted}``"
437439

438440

439-
def process_docstring(app: Sphinx, what, name, obj, options, lines): # noqa: U100
441+
def process_docstring(
442+
app: Sphinx, what: str, name: str, obj: Any, options: Options | None, lines: list[str] # noqa: U100
443+
) -> None:
440444
original_obj = obj
441445
if isinstance(obj, property):
442446
obj = obj.fget
@@ -519,13 +523,13 @@ def builder_ready(app: Sphinx) -> None:
519523
typing.TYPE_CHECKING = True
520524

521525

522-
def validate_config(app: Sphinx, *args) -> None: # noqa: U100
526+
def validate_config(app: Sphinx, env: BuildEnvironment, docnames: list[str]) -> None: # noqa: U100
523527
valid = {None, "comma", "braces", "braces-after"}
524528
if app.config.typehints_defaults not in valid | {False}:
525529
raise ValueError(f"typehints_defaults needs to be one of {valid!r}, not {app.config.typehints_defaults!r}")
526530

527531

528-
def setup(app: Sphinx) -> Dict[str, bool]:
532+
def setup(app: Sphinx) -> dict[str, bool]:
529533
app.add_config_value("set_type_checking_flag", False, "html")
530534
app.add_config_value("always_document_param_types", False, "html")
531535
app.add_config_value("typehints_fully_qualified", False, "env")

src/sphinx_autodoc_typehints/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)