5
5
import textwrap
6
6
import typing
7
7
from ast import FunctionDef , Module , stmt
8
+ from functools import partial
8
9
from typing import Any , AnyStr , NewType , TypeVar , get_type_hints
9
10
10
11
from sphinx .application import Sphinx
16
17
17
18
from .version import version as __version__
18
19
19
- logger = logging .getLogger (__name__ )
20
- pydata_annotations = {"Any" , "AnyStr" , "Callable" , "ClassVar" , "Literal" , "NoReturn" , "Optional" , "Tuple" , "Union" }
21
-
22
- __all__ = [
23
- "__version__" ,
24
- ]
20
+ _LOGGER = logging .getLogger (__name__ )
21
+ _PYDATA_ANNOTATIONS = {"Any" , "AnyStr" , "Callable" , "ClassVar" , "Literal" , "NoReturn" , "Optional" , "Tuple" , "Union" }
25
22
26
23
27
24
def get_annotation_module (annotation : Any ) -> str :
28
- # Special cases
29
25
if annotation is None :
30
26
return "builtins"
31
-
32
27
if sys .version_info >= (3 , 10 ) and isinstance (annotation , NewType ): # type: ignore # isinstance NewType is Callable
33
28
return "typing"
34
-
35
29
if hasattr (annotation , "__module__" ):
36
30
return annotation .__module__ # type: ignore # deduced Any
37
-
38
31
if hasattr (annotation , "__origin__" ):
39
32
return annotation .__origin__ .__module__ # type: ignore # deduced Any
40
-
41
33
raise ValueError (f"Cannot determine the module of { annotation } " )
42
34
43
35
@@ -124,7 +116,7 @@ def format_annotation(annotation: Any, fully_qualified: bool = False, simplify_o
124
116
125
117
full_name = f"{ module } .{ class_name } " if module != "builtins" else class_name
126
118
prefix = "" if fully_qualified or full_name == class_name else "~"
127
- role = "data" if class_name in pydata_annotations else "class"
119
+ role = "data" if class_name in _PYDATA_ANNOTATIONS else "class"
128
120
args_format = "\\ [{}]"
129
121
formatted_args = ""
130
122
@@ -232,7 +224,7 @@ def _is_dataclass(name: str, what: str, qualname: str) -> bool:
232
224
return False
233
225
234
226
if "<locals>" in obj .__qualname__ and not _is_dataclass (name , what , obj .__qualname__ ):
235
- logger .warning ('Cannot treat a function defined as a local function: "%s" (use @functools.wraps)' , name )
227
+ _LOGGER .warning ('Cannot treat a function defined as a local function: "%s" (use @functools.wraps)' , name )
236
228
return None
237
229
238
230
if parameters :
@@ -287,7 +279,7 @@ def get_all_type_hints(obj: Any, name: str) -> dict[str, Any]:
287
279
if isinstance (exc , TypeError ) and _future_annotations_imported (obj ) and "unsupported operand type" in str (exc ):
288
280
rv = obj .__annotations__
289
281
except NameError as exc :
290
- logger .warning ('Cannot resolve forward reference in type annotations of "%s": %s' , name , exc )
282
+ _LOGGER .warning ('Cannot resolve forward reference in type annotations of "%s": %s' , name , exc )
291
283
rv = obj .__annotations__
292
284
293
285
if rv :
@@ -305,7 +297,7 @@ def get_all_type_hints(obj: Any, name: str) -> dict[str, Any]:
305
297
except (AttributeError , TypeError ):
306
298
pass
307
299
except NameError as exc :
308
- logger .warning ('Cannot resolve forward reference in type annotations of "%s": %s' , name , exc )
300
+ _LOGGER .warning ('Cannot resolve forward reference in type annotations of "%s": %s' , name , exc )
309
301
rv = obj .__annotations__
310
302
311
303
return rv
@@ -327,7 +319,7 @@ def _one_child(module: Module) -> stmt | None:
327
319
children = module .body # use the body to ignore type comments
328
320
329
321
if len (children ) != 1 :
330
- logger .warning ('Did not get exactly one node from AST for "%s", got %s' , name , len (children ))
322
+ _LOGGER .warning ('Did not get exactly one node from AST for "%s", got %s' , name , len (children ))
331
323
return None
332
324
333
325
return children [0 ]
@@ -353,7 +345,7 @@ def _one_child(module: Module) -> stmt | None:
353
345
try :
354
346
comment_args_str , comment_returns = type_comment .split (" -> " )
355
347
except ValueError :
356
- logger .warning ('Unparseable type hint comment for "%s": Expected to contain ` -> `' , name )
348
+ _LOGGER .warning ('Unparseable type hint comment for "%s": Expected to contain ` -> `' , name )
357
349
return {}
358
350
359
351
rv = {}
@@ -368,7 +360,7 @@ def _one_child(module: Module) -> stmt | None:
368
360
comment_args .insert (0 , None ) # self/cls may be omitted in type comments, insert blank
369
361
370
362
if len (args ) != len (comment_args ):
371
- logger .warning ('Not enough type comments found on "%s"' , name )
363
+ _LOGGER .warning ('Not enough type comments found on "%s"' , name )
372
364
return rv
373
365
374
366
for at , arg in enumerate (args ):
@@ -442,80 +434,71 @@ def process_docstring(
442
434
app : Sphinx , what : str , name : str , obj : Any , options : Options | None , lines : list [str ] # noqa: U100
443
435
) -> None :
444
436
original_obj = obj
445
- if isinstance (obj , property ):
446
- obj = obj .fget
447
-
448
- if callable (obj ):
449
- if inspect .isclass (obj ):
450
- obj = obj .__init__
437
+ obj = obj .fget if isinstance (obj , property ) else obj
438
+ if not callable (obj ):
439
+ return
440
+ obj = obj .__init__ if inspect .isclass (obj ) else obj
441
+ obj = inspect .unwrap (obj )
451
442
452
- obj = inspect . unwrap ( obj )
443
+ try :
453
444
signature = sphinx_signature (obj )
454
- type_hints = get_all_type_hints (obj , name )
455
-
456
- for arg_name , annotation in type_hints .items ():
457
- if arg_name == "return" :
458
- continue # this is handled separately later
459
- default = signature .parameters [arg_name ].default
460
- if arg_name .endswith ("_" ):
461
- arg_name = f"{ arg_name [:- 1 ]} \\ _"
462
-
463
- formatted_annotation = format_annotation (
464
- annotation ,
465
- fully_qualified = app .config .typehints_fully_qualified ,
466
- simplify_optional_unions = app .config .simplify_optional_unions ,
467
- )
468
-
469
- search_for = [f":{ field } { arg_name } :" for field in ("param" , "parameter" , "arg" , "argument" )]
470
- insert_index = None
471
-
472
- for i , line in enumerate (lines ):
473
- if any (line .startswith (search_string ) for search_string in search_for ):
474
- insert_index = i
475
- break
476
-
477
- if insert_index is None and app .config .always_document_param_types :
478
- lines .append (f":param { arg_name } :" )
479
- insert_index = len (lines )
480
-
481
- if insert_index is not None :
482
- type_annotation = f":type { arg_name } : { formatted_annotation } "
483
- if app .config .typehints_defaults :
484
- formatted_default = format_default (app , default )
485
- if formatted_default :
486
- if app .config .typehints_defaults .endswith ("after" ):
487
- lines [insert_index ] += formatted_default
488
- else : # add to last param doc line
489
- type_annotation += formatted_default
490
- lines .insert (insert_index , type_annotation )
491
-
492
- if "return" in type_hints and not inspect .isclass (original_obj ):
493
- # This avoids adding a return type for data class __init__ methods
494
- if what == "method" and name .endswith (".__init__" ):
495
- return
496
-
497
- formatted_annotation = format_annotation (
498
- type_hints ["return" ],
499
- fully_qualified = app .config .typehints_fully_qualified ,
500
- simplify_optional_unions = app .config .simplify_optional_unions ,
501
- )
502
-
445
+ except (ValueError , TypeError ):
446
+ signature = None
447
+ type_hints = get_all_type_hints (obj , name )
448
+
449
+ formatter = partial (
450
+ format_annotation ,
451
+ fully_qualified = app .config .typehints_fully_qualified ,
452
+ simplify_optional_unions = app .config .simplify_optional_unions ,
453
+ )
454
+ for arg_name , annotation in type_hints .items ():
455
+ if arg_name == "return" :
456
+ continue # this is handled separately later
457
+ default = inspect .Parameter .empty if signature is None else signature .parameters [arg_name ].default
458
+ if arg_name .endswith ("_" ):
459
+ arg_name = f"{ arg_name [:- 1 ]} \\ _"
460
+
461
+ formatted_annotation = formatter (annotation )
462
+
463
+ search_for = {f":{ field } { arg_name } :" for field in ("param" , "parameter" , "arg" , "argument" )}
464
+ insert_index = None
465
+ for at , line in enumerate (lines ):
466
+ if any (line .startswith (search_string ) for search_string in search_for ):
467
+ insert_index = at
468
+ break
469
+
470
+ if insert_index is None and app .config .always_document_param_types :
471
+ lines .append (f":param { arg_name } :" )
503
472
insert_index = len (lines )
504
- for i , line in enumerate (lines ):
505
- if line .startswith (":rtype:" ):
506
- insert_index = None
507
- break
508
- elif line .startswith (":return:" ) or line .startswith (":returns:" ):
509
- insert_index = i
510
-
511
- if insert_index is not None and app .config .typehints_document_rtype :
512
- if insert_index == len (lines ):
513
- # Ensure that :rtype: doesn't get joined with a paragraph of text, which
514
- # prevents it being interpreted.
515
- lines .append ("" )
516
- insert_index += 1
517
473
518
- lines .insert (insert_index , f":rtype: { formatted_annotation } " )
474
+ if insert_index is not None :
475
+ type_annotation = f":type { arg_name } : { formatted_annotation } "
476
+ if app .config .typehints_defaults :
477
+ formatted_default = format_default (app , default )
478
+ if formatted_default :
479
+ if app .config .typehints_defaults .endswith ("after" ):
480
+ lines [insert_index ] += formatted_default
481
+ else : # add to last param doc line
482
+ type_annotation += formatted_default
483
+ lines .insert (insert_index , type_annotation )
484
+
485
+ if "return" in type_hints and not inspect .isclass (original_obj ):
486
+ if what == "method" and name .endswith (".__init__" ): # avoid adding a return type for data class __init__
487
+ return
488
+ formatted_annotation = formatter (type_hints ["return" ])
489
+ insert_index = len (lines )
490
+ for at , line in enumerate (lines ):
491
+ if line .startswith (":rtype:" ):
492
+ insert_index = None
493
+ break
494
+ elif line .startswith (":return:" ) or line .startswith (":returns:" ):
495
+ insert_index = at
496
+
497
+ if insert_index is not None and app .config .typehints_document_rtype :
498
+ if insert_index == len (lines ): # ensure that :rtype: doesn't get joined with a paragraph of text
499
+ lines .append ("" )
500
+ insert_index += 1
501
+ lines .insert (insert_index , f":rtype: { formatted_annotation } " )
519
502
520
503
521
504
def builder_ready (app : Sphinx ) -> None :
@@ -541,3 +524,15 @@ def setup(app: Sphinx) -> dict[str, bool]:
541
524
app .connect ("autodoc-process-signature" , process_signature )
542
525
app .connect ("autodoc-process-docstring" , process_docstring )
543
526
return {"parallel_read_safe" : True }
527
+
528
+
529
+ __all__ = [
530
+ "__version__" ,
531
+ "format_annotation" ,
532
+ "get_annotation_args" ,
533
+ "get_annotation_class_name" ,
534
+ "get_annotation_module" ,
535
+ "normalize_source_lines" ,
536
+ "process_docstring" ,
537
+ "process_signature" ,
538
+ ]
0 commit comments