@@ -54,9 +54,13 @@ def _preprocess_default(self, val: ast.AST) -> str:
54
54
text_default_val = self ._PAREN_NUMBER_RE .sub ('\\ 1' , text_default_val )
55
55
return text_default_val
56
56
57
+ def extract (self , obj : Any ):
58
+ obj_source = textwrap .dedent (inspect .getsource (obj ))
59
+ obj_ast = ast .parse (obj_source )
60
+ self .visit (obj_ast )
57
61
58
- class _CallableDefaultAndAnnotationExtractor ( _BaseDefaultAndAnnotationExtractor
59
- ):
62
+
63
+ class _ArgDefaultAndAnnotationExtractor ( _BaseDefaultAndAnnotationExtractor ):
60
64
"""Extracts the type annotations by parsing the AST of a function."""
61
65
62
66
def visit_FunctionDef (self , node ) -> None : # pylint: disable=invalid-name
@@ -95,8 +99,7 @@ def visit_FunctionDef(self, node) -> None: # pylint: disable=invalid-name
95
99
self .defaults [kwarg .arg ] = text_default_val
96
100
97
101
98
- class _DataclassDefaultAndAnnotationExtractor (_BaseDefaultAndAnnotationExtractor
99
- ):
102
+ class _ClassDefaultAndAnnotationExtractor (_BaseDefaultAndAnnotationExtractor ):
100
103
"""Extracts the type annotations by parsing the AST of a dataclass."""
101
104
102
105
def __init__ (self ):
@@ -111,6 +114,8 @@ def visit_ClassDef(self, node) -> None: # pylint: disable=invalid-name
111
114
for sub in node .body :
112
115
if isinstance (sub , ast .AnnAssign ):
113
116
self .visit_AnnAssign (sub )
117
+ elif isinstance (sub , ast .Assign ):
118
+ self .visit_Assign (sub )
114
119
115
120
def visit_AnnAssign (self , node ) -> None : # pylint: disable=invalid-name
116
121
"""Vists an assignment with a type annotation. Dataclasses is an example."""
@@ -120,6 +125,20 @@ def visit_AnnAssign(self, node) -> None: # pylint: disable=invalid-name
120
125
if node .value is not None :
121
126
self .defaults [arg ] = self ._preprocess_default (node .value )
122
127
128
+ def visit_Assign (self , node ) -> None : # pylint: disable=invalid-name
129
+ """Vists an assignment with a type annotation. Dataclasses is an example."""
130
+ names = [_source_from_ast (t ) for t in node .targets ]
131
+ if node .value is not None :
132
+ val = self ._preprocess_default (node .value )
133
+ for name in names :
134
+ self .defaults [name ] = val
135
+
136
+ def extract (self , cls ):
137
+ # Iterate over the classes in reverse order so each class overwrites it's
138
+ # parents. Skip `object`.
139
+ for cls in reversed (cls .__mro__ [:- 1 ]):
140
+ super ().extract (cls )
141
+
123
142
124
143
_OBJECT_MEMORY_ADDRESS_RE = re .compile (r'<(?P<type>.+?) at 0x[\da-f]+>' )
125
144
@@ -562,9 +581,9 @@ def generate_signature(
562
581
563
582
if dataclasses .is_dataclass (func ):
564
583
sig = sig .replace (return_annotation = EMPTY )
565
- extract_fn = _extract_dataclass_defaults_and_annotations
584
+ extract_fn = _extract_class_defaults_and_annotations
566
585
else :
567
- extract_fn = _extract_callable_defaults_and_annotations
586
+ extract_fn = _extract_arg_defaults_and_annotations
568
587
569
588
(annotation_source_dict , defaults_source_dict ,
570
589
return_annotation_source ) = extract_fn (func )
@@ -598,44 +617,29 @@ def generate_signature(
598
617
AnnotsDefaultsReturns = Tuple [Dict [str , str ], Dict [str , str ], Any ]
599
618
600
619
601
- def _extract_dataclass_defaults_and_annotations (
602
- func : Type [object ]) -> AnnotsDefaultsReturns :
620
+ def _extract_class_defaults_and_annotations (
621
+ cls : Type [object ]) -> AnnotsDefaultsReturns :
603
622
"""Extract ast defaults and annotations form a dataclass."""
604
- stack = [c for c in func .__mro__ if dataclasses .is_dataclass (c )]
605
-
606
- annotation_source_dict = {}
607
- defaults_source_dict = {}
608
- return_annotation_source = EMPTY
609
-
610
- # Iterate over the classes in reverse order so precedence works.
611
- for cls in reversed (stack ):
612
- func_source = textwrap .dedent (inspect .getsource (cls ))
613
- func_ast = ast .parse (func_source )
614
- # Extract the type annotation from the parsed ast.
615
- ast_visitor = _DataclassDefaultAndAnnotationExtractor ()
616
- ast_visitor .visit (func_ast )
623
+ ast_visitor = _ClassDefaultAndAnnotationExtractor ()
624
+ ast_visitor .extract (cls )
617
625
618
- annotation_source_dict .update (ast_visitor .annotations )
619
- defaults_source_dict .update (ast_visitor .defaults )
620
-
621
- return annotation_source_dict , defaults_source_dict , return_annotation_source
626
+ return (ast_visitor .annotations , ast_visitor .defaults ,
627
+ ast_visitor .return_annotation )
622
628
623
629
624
- def _extract_callable_defaults_and_annotations (
630
+ def _extract_arg_defaults_and_annotations (
625
631
func : Callable [..., Any ]) -> AnnotsDefaultsReturns :
626
632
"""Extract ast defaults and annotations form a standard callable."""
627
633
628
- ast_visitor = _CallableDefaultAndAnnotationExtractor ()
634
+ ast_visitor = _ArgDefaultAndAnnotationExtractor ()
629
635
630
636
annotation_source_dict = {}
631
637
defaults_source_dict = {}
632
638
return_annotation_source = EMPTY
633
639
634
640
try :
635
- func_source = textwrap .dedent (inspect .getsource (func ))
636
- func_ast = ast .parse (func_source )
637
641
# Extract the type annotation from the parsed ast.
638
- ast_visitor .visit ( func_ast )
642
+ ast_visitor .extract ( func )
639
643
except Exception : # pylint: disable=broad-except
640
644
# A wide-variety of errors can be thrown here.
641
645
pass
0 commit comments