1
+ from __future__ import annotations
2
+
1
3
import inspect
2
4
import sys
3
5
import textwrap
4
6
import typing
5
- from typing import Any , AnyStr , Dict , NewType , Optional , Tuple , TypeVar , get_type_hints
7
+ from ast import FunctionDef , Module , stmt
8
+ from typing import Any , AnyStr , NewType , TypeVar , get_type_hints
6
9
7
10
from sphinx .application import Sphinx
11
+ from sphinx .environment import BuildEnvironment
12
+ from sphinx .ext .autodoc import Options
8
13
from sphinx .util import logging
9
14
from sphinx .util .inspect import signature as sphinx_signature
10
15
from sphinx .util .inspect import stringify_signature
19
24
]
20
25
21
26
22
- def get_annotation_module (annotation ) -> str :
27
+ def get_annotation_module (annotation : Any ) -> str :
23
28
# Special cases
24
29
if annotation is None :
25
30
return "builtins"
26
31
27
- if sys .version_info >= (3 , 10 ) and isinstance (annotation , NewType ):
32
+ if sys .version_info >= (3 , 10 ) and isinstance (annotation , NewType ): # type: ignore # isinstance NewType is Callable
28
33
return "typing"
29
34
30
35
if hasattr (annotation , "__module__" ):
31
- return annotation .__module__
36
+ return annotation .__module__ # type: ignore # deduced Any
32
37
33
38
if hasattr (annotation , "__origin__" ):
34
- return annotation .__origin__ .__module__
39
+ return annotation .__origin__ .__module__ # type: ignore # deduced Any
35
40
36
41
raise ValueError (f"Cannot determine the module of { annotation } " )
37
42
38
43
39
- def get_annotation_class_name (annotation , module : str ) -> str :
44
+ def get_annotation_class_name (annotation : Any , module : str ) -> str :
40
45
# Special cases
41
46
if annotation is None :
42
47
return "None"
@@ -45,32 +50,30 @@ def get_annotation_class_name(annotation, module: str) -> str:
45
50
elif annotation is AnyStr :
46
51
return "AnyStr"
47
52
elif (sys .version_info < (3 , 10 ) and inspect .isfunction (annotation ) and hasattr (annotation , "__supertype__" )) or (
48
- sys .version_info >= (3 , 10 ) and isinstance (annotation , NewType )
53
+ sys .version_info >= (3 , 10 ) and isinstance (annotation , NewType ) # type: ignore # isinstance NewType is Callable
49
54
):
50
55
return "NewType"
51
56
52
57
if getattr (annotation , "__qualname__" , None ):
53
- return annotation .__qualname__
58
+ return annotation .__qualname__ # type: ignore # deduced Any
54
59
elif getattr (annotation , "_name" , None ): # Required for generic aliases on Python 3.7+
55
- return annotation ._name
60
+ return annotation ._name # type: ignore # deduced Any
56
61
elif module in ("typing" , "typing_extensions" ) and isinstance (getattr (annotation , "name" , None ), str ):
57
62
# Required for at least Pattern and Match
58
- return annotation .name
63
+ return annotation .name # type: ignore # deduced Any
59
64
60
65
origin = getattr (annotation , "__origin__" , None )
61
66
if origin :
62
67
if getattr (origin , "__qualname__" , None ): # Required for Protocol subclasses
63
- return origin .__qualname__
68
+ return origin .__qualname__ # type: ignore # deduced Any
64
69
elif getattr (origin , "_name" , None ): # Required for Union on Python 3.7+
65
- return origin ._name
66
- else :
67
- return origin .__class__ .__qualname__ .lstrip ("_" ) # Required for Union on Python < 3.7
70
+ return origin ._name # type: ignore # deduced Any
68
71
69
72
annotation_cls = annotation if inspect .isclass (annotation ) else annotation .__class__
70
- return annotation_cls .__qualname__ .lstrip ("_" )
73
+ return annotation_cls .__qualname__ .lstrip ("_" ) # type: ignore # deduced Any
71
74
72
75
73
- def get_annotation_args (annotation , module : str , class_name : str ) -> Tuple :
76
+ def get_annotation_args (annotation : Any , module : str , class_name : str ) -> tuple [ Any , ...] :
74
77
try :
75
78
original = getattr (sys .modules [module ], class_name )
76
79
except (KeyError , AttributeError ):
@@ -87,14 +90,14 @@ def get_annotation_args(annotation, module: str, class_name: str) -> Tuple:
87
90
elif class_name == "NewType" and hasattr (annotation , "__supertype__" ):
88
91
return (annotation .__supertype__ ,)
89
92
elif class_name == "Literal" and hasattr (annotation , "__values__" ):
90
- return annotation .__values__
93
+ return annotation .__values__ # type: ignore # deduced Any
91
94
elif class_name == "Generic" :
92
- return annotation .__parameters__
95
+ return annotation .__parameters__ # type: ignore # deduced Any
93
96
94
97
return getattr (annotation , "__args__" , ())
95
98
96
99
97
- def format_annotation (annotation , fully_qualified : bool = False , simplify_optional_unions : bool = True ) -> str :
100
+ def format_annotation (annotation : Any , fully_qualified : bool = False , simplify_optional_unions : bool = True ) -> str :
98
101
# Special cases
99
102
if annotation is None or annotation is type (None ): # noqa: E721
100
103
return ":py:obj:`None`"
@@ -154,25 +157,25 @@ def format_annotation(annotation, fully_qualified: bool = False, simplify_option
154
157
155
158
156
159
# reference: https://github.com/pytorch/pytorch/pull/46548/files
157
- def normalize_source_lines (sourcelines : str ) -> str :
160
+ def normalize_source_lines (source_lines : str ) -> str :
158
161
"""
159
162
This helper function accepts a list of source lines. It finds the
160
163
indentation level of the function definition (`def`), then it indents
161
164
all lines in the function body to a point at or greater than that
162
165
level. This allows for comments and continued string literals that
163
166
are at a lower indentation than the rest of the code.
164
167
Arguments:
165
- sourcelines : source code
168
+ source_lines : source code
166
169
Returns:
167
170
source lines that have been correctly aligned
168
171
"""
169
- sourcelines = sourcelines .split ("\n " )
172
+ lines = source_lines .split ("\n " )
170
173
171
- def remove_prefix (text , prefix ) :
174
+ def remove_prefix (text : str , prefix : str ) -> str :
172
175
return text [text .startswith (prefix ) and len (prefix ) :]
173
176
174
177
# Find the line and line number containing the function definition
175
- for i , l in enumerate (sourcelines ):
178
+ for i , l in enumerate (lines ):
176
179
if l .lstrip ().startswith ("def" ):
177
180
idx = i
178
181
whitespace_separator = "def"
@@ -183,35 +186,37 @@ def remove_prefix(text, prefix):
183
186
break
184
187
185
188
else :
186
- return "\n " .join (sourcelines )
187
- fn_def = sourcelines [idx ]
189
+ return "\n " .join (lines )
190
+ fn_def = lines [idx ]
188
191
189
192
# Get a string representing the amount of leading whitespace
190
193
whitespace = fn_def .split (whitespace_separator )[0 ]
191
194
192
195
# Add this leading whitespace to all lines before and after the `def`
193
- aligned_prefix = [whitespace + remove_prefix (s , whitespace ) for s in sourcelines [:idx ]]
194
- aligned_suffix = [whitespace + remove_prefix (s , whitespace ) for s in sourcelines [idx + 1 :]]
196
+ aligned_prefix = [whitespace + remove_prefix (s , whitespace ) for s in lines [:idx ]]
197
+ aligned_suffix = [whitespace + remove_prefix (s , whitespace ) for s in lines [idx + 1 :]]
195
198
196
199
# Put it together again
197
200
aligned_prefix .append (fn_def )
198
201
return "\n " .join (aligned_prefix + aligned_suffix )
199
202
200
203
201
- def process_signature (app , what : str , name : str , obj , options , signature , return_annotation ): # noqa: U100
204
+ def process_signature (
205
+ app : Sphinx , what : str , name : str , obj : Any , options : Options , signature : str , return_annotation : str # noqa: U100
206
+ ) -> tuple [str , None ] | None :
202
207
if not callable (obj ):
203
- return
208
+ return None
204
209
205
210
original_obj = obj
206
211
if inspect .isclass (obj ):
207
212
obj = getattr (obj , "__init__" , getattr (obj , "__new__" , None ))
208
213
209
214
if not getattr (obj , "__annotations__" , None ):
210
- return
215
+ return None
211
216
212
217
obj = inspect .unwrap (obj )
213
- signature = sphinx_signature (obj )
214
- parameters = [param .replace (annotation = inspect .Parameter .empty ) for param in signature .parameters .values ()]
218
+ sph_signature = sphinx_signature (obj )
219
+ parameters = [param .replace (annotation = inspect .Parameter .empty ) for param in sph_signature .parameters .values ()]
215
220
216
221
# The generated dataclass __init__() and class are weird and need extra checks
217
222
# This helper function operates on the generated class and methods
@@ -228,15 +233,15 @@ def _is_dataclass(name: str, what: str, qualname: str) -> bool:
228
233
229
234
if "<locals>" in obj .__qualname__ and not _is_dataclass (name , what , obj .__qualname__ ):
230
235
logger .warning ('Cannot treat a function defined as a local function: "%s" (use @functools.wraps)' , name )
231
- return
236
+ return None
232
237
233
238
if parameters :
234
239
if inspect .isclass (original_obj ) or (what == "method" and name .endswith (".__init__" )):
235
240
del parameters [0 ]
236
241
elif what == "method" :
237
242
outer = inspect .getmodule (obj )
238
- for clsname in obj .__qualname__ .split ("." )[:- 1 ]:
239
- outer = getattr (outer , clsname )
243
+ for class_name in obj .__qualname__ .split ("." )[:- 1 ]:
244
+ outer = getattr (outer , class_name )
240
245
241
246
method_name = obj .__name__
242
247
if method_name .startswith ("__" ) and not method_name .endswith ("__" ):
@@ -250,27 +255,23 @@ def _is_dataclass(name: str, what: str, qualname: str) -> bool:
250
255
if not isinstance (method_object , (classmethod , staticmethod )):
251
256
del parameters [0 ]
252
257
253
- signature = signature .replace (parameters = parameters , return_annotation = inspect .Signature .empty )
254
-
255
- return stringify_signature (signature ).replace ("\\ " , "\\ \\ " ), None
258
+ sph_signature = sph_signature .replace (parameters = parameters , return_annotation = inspect .Signature .empty )
256
259
260
+ return stringify_signature (sph_signature ).replace ("\\ " , "\\ \\ " ), None
257
261
258
- def _future_annotations_imported (obj ):
259
- if sys .version_info < (3 , 7 ):
260
- # Only Python ≥ 3.7 supports PEP563.
261
- return False
262
262
263
+ def _future_annotations_imported (obj : Any ) -> bool :
263
264
_annotations = getattr (inspect .getmodule (obj ), "annotations" , None )
264
265
if _annotations is None :
265
266
return False
266
267
267
268
# Make sure that annotations is imported from __future__ - defined in cpython/Lib/__future__.py
268
269
# annotations become strings at runtime
269
270
future_annotations = 0x100000 if sys .version_info [0 :2 ] == (3 , 7 ) else 0x1000000
270
- return _annotations .compiler_flag == future_annotations
271
+ return bool ( _annotations .compiler_flag == future_annotations )
271
272
272
273
273
- def get_all_type_hints (obj , name ) :
274
+ def get_all_type_hints (obj : Any , name : str ) -> dict [ str , Any ] :
274
275
rv = {}
275
276
276
277
try :
@@ -310,7 +311,7 @@ def get_all_type_hints(obj, name):
310
311
return rv
311
312
312
313
313
- def backfill_type_hints (obj , name ) :
314
+ def backfill_type_hints (obj : Any , name : str ) -> dict [ str , Any ] :
314
315
parse_kwargs = {}
315
316
if sys .version_info < (3 , 8 ):
316
317
try :
@@ -322,17 +323,18 @@ def backfill_type_hints(obj, name):
322
323
323
324
parse_kwargs = {"type_comments" : True }
324
325
325
- def _one_child (module ) :
326
+ def _one_child (module : Module ) -> stmt | None :
326
327
children = module .body # use the body to ignore type comments
327
328
328
329
if len (children ) != 1 :
329
330
logger .warning ('Did not get exactly one node from AST for "%s", got %s' , name , len (children ))
330
- return
331
+ return None
331
332
332
333
return children [0 ]
333
334
334
335
try :
335
- obj_ast = ast .parse (textwrap .dedent (normalize_source_lines (inspect .getsource (obj ))), ** parse_kwargs )
336
+ code = textwrap .dedent (normalize_source_lines (inspect .getsource (obj )))
337
+ obj_ast = ast .parse (code , ** parse_kwargs ) # type: ignore # dynamic kwargs
336
338
except (OSError , TypeError ):
337
339
return {}
338
340
@@ -385,7 +387,7 @@ def _one_child(module):
385
387
return rv
386
388
387
389
388
- def load_args (obj_ast ) :
390
+ def load_args (obj_ast : FunctionDef ) -> list [ Any ] :
389
391
func_args = obj_ast .args
390
392
args = []
391
393
pos_only = getattr (func_args , "posonlyargs" , None )
@@ -403,12 +405,12 @@ def load_args(obj_ast):
403
405
return args
404
406
405
407
406
- def split_type_comment_args (comment ) :
407
- def add (val ) :
408
+ def split_type_comment_args (comment : str ) -> list [ str | None ] :
409
+ def add (val : str ) -> None :
408
410
result .append (val .strip ().lstrip ("*" )) # remove spaces, and var/kw arg marker
409
411
410
412
comment = comment .strip ().lstrip ("(" ).rstrip (")" )
411
- result = []
413
+ result : list [ str | None ] = []
412
414
if not comment :
413
415
return result
414
416
@@ -426,7 +428,7 @@ def add(val):
426
428
return result
427
429
428
430
429
- def format_default (app : Sphinx , default : Any ) -> Optional [ str ] :
431
+ def format_default (app : Sphinx , default : Any ) -> str | None :
430
432
if default is inspect .Parameter .empty :
431
433
return None
432
434
formatted = repr (default ).replace ("\\ " , "\\ \\ " )
@@ -436,7 +438,9 @@ def format_default(app: Sphinx, default: Any) -> Optional[str]:
436
438
return f", default: ``{ formatted } ``"
437
439
438
440
439
- def process_docstring (app : Sphinx , what , name , obj , options , lines ): # noqa: U100
441
+ def process_docstring (
442
+ app : Sphinx , what : str , name : str , obj : Any , options : Options | None , lines : list [str ] # noqa: U100
443
+ ) -> None :
440
444
original_obj = obj
441
445
if isinstance (obj , property ):
442
446
obj = obj .fget
@@ -519,13 +523,13 @@ def builder_ready(app: Sphinx) -> None:
519
523
typing .TYPE_CHECKING = True
520
524
521
525
522
- def validate_config (app : Sphinx , * args ) -> None : # noqa: U100
526
+ def validate_config (app : Sphinx , env : BuildEnvironment , docnames : list [ str ] ) -> None : # noqa: U100
523
527
valid = {None , "comma" , "braces" , "braces-after" }
524
528
if app .config .typehints_defaults not in valid | {False }:
525
529
raise ValueError (f"typehints_defaults needs to be one of { valid !r} , not { app .config .typehints_defaults !r} " )
526
530
527
531
528
- def setup (app : Sphinx ) -> Dict [str , bool ]:
532
+ def setup (app : Sphinx ) -> dict [str , bool ]:
529
533
app .add_config_value ("set_type_checking_flag" , False , "html" )
530
534
app .add_config_value ("always_document_param_types" , False , "html" )
531
535
app .add_config_value ("typehints_fully_qualified" , False , "env" )
0 commit comments