Skip to content

Commit 2da2058

Browse files
MarkDaoustcopybara-github
authored andcommitted
Upgrade the class-defaults extractor to work on without annotations.
Now this can work on any class (maybe modules too), not only `dataclasses` PiperOrigin-RevId: 436961154
1 parent ba5b597 commit 2da2058

File tree

2 files changed

+56
-30
lines changed

2 files changed

+56
-30
lines changed

tools/tensorflow_docs/api_generator/signature.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,13 @@ def _preprocess_default(self, val: ast.AST) -> str:
5454
text_default_val = self._PAREN_NUMBER_RE.sub('\\1', text_default_val)
5555
return text_default_val
5656

57+
def extract(self, obj: Any):
58+
obj_source = textwrap.dedent(inspect.getsource(obj))
59+
obj_ast = ast.parse(obj_source)
60+
self.visit(obj_ast)
5761

58-
class _CallableDefaultAndAnnotationExtractor(_BaseDefaultAndAnnotationExtractor
59-
):
62+
63+
class _ArgDefaultAndAnnotationExtractor(_BaseDefaultAndAnnotationExtractor):
6064
"""Extracts the type annotations by parsing the AST of a function."""
6165

6266
def visit_FunctionDef(self, node) -> None: # pylint: disable=invalid-name
@@ -95,8 +99,7 @@ def visit_FunctionDef(self, node) -> None: # pylint: disable=invalid-name
9599
self.defaults[kwarg.arg] = text_default_val
96100

97101

98-
class _DataclassDefaultAndAnnotationExtractor(_BaseDefaultAndAnnotationExtractor
99-
):
102+
class _ClassDefaultAndAnnotationExtractor(_BaseDefaultAndAnnotationExtractor):
100103
"""Extracts the type annotations by parsing the AST of a dataclass."""
101104

102105
def __init__(self):
@@ -111,6 +114,8 @@ def visit_ClassDef(self, node) -> None: # pylint: disable=invalid-name
111114
for sub in node.body:
112115
if isinstance(sub, ast.AnnAssign):
113116
self.visit_AnnAssign(sub)
117+
elif isinstance(sub, ast.Assign):
118+
self.visit_Assign(sub)
114119

115120
def visit_AnnAssign(self, node) -> None: # pylint: disable=invalid-name
116121
"""Vists an assignment with a type annotation. Dataclasses is an example."""
@@ -120,6 +125,20 @@ def visit_AnnAssign(self, node) -> None: # pylint: disable=invalid-name
120125
if node.value is not None:
121126
self.defaults[arg] = self._preprocess_default(node.value)
122127

128+
def visit_Assign(self, node) -> None: # pylint: disable=invalid-name
129+
"""Vists an assignment with a type annotation. Dataclasses is an example."""
130+
names = [_source_from_ast(t) for t in node.targets]
131+
if node.value is not None:
132+
val = self._preprocess_default(node.value)
133+
for name in names:
134+
self.defaults[name] = val
135+
136+
def extract(self, cls):
137+
# Iterate over the classes in reverse order so each class overwrites it's
138+
# parents. Skip `object`.
139+
for cls in reversed(cls.__mro__[:-1]):
140+
super().extract(cls)
141+
123142

124143
_OBJECT_MEMORY_ADDRESS_RE = re.compile(r'<(?P<type>.+?) at 0x[\da-f]+>')
125144

@@ -562,9 +581,9 @@ def generate_signature(
562581

563582
if dataclasses.is_dataclass(func):
564583
sig = sig.replace(return_annotation=EMPTY)
565-
extract_fn = _extract_dataclass_defaults_and_annotations
584+
extract_fn = _extract_class_defaults_and_annotations
566585
else:
567-
extract_fn = _extract_callable_defaults_and_annotations
586+
extract_fn = _extract_arg_defaults_and_annotations
568587

569588
(annotation_source_dict, defaults_source_dict,
570589
return_annotation_source) = extract_fn(func)
@@ -598,44 +617,29 @@ def generate_signature(
598617
AnnotsDefaultsReturns = Tuple[Dict[str, str], Dict[str, str], Any]
599618

600619

601-
def _extract_dataclass_defaults_and_annotations(
602-
func: Type[object]) -> AnnotsDefaultsReturns:
620+
def _extract_class_defaults_and_annotations(
621+
cls: Type[object]) -> AnnotsDefaultsReturns:
603622
"""Extract ast defaults and annotations form a dataclass."""
604-
stack = [c for c in func.__mro__ if dataclasses.is_dataclass(c)]
605-
606-
annotation_source_dict = {}
607-
defaults_source_dict = {}
608-
return_annotation_source = EMPTY
609-
610-
# Iterate over the classes in reverse order so precedence works.
611-
for cls in reversed(stack):
612-
func_source = textwrap.dedent(inspect.getsource(cls))
613-
func_ast = ast.parse(func_source)
614-
# Extract the type annotation from the parsed ast.
615-
ast_visitor = _DataclassDefaultAndAnnotationExtractor()
616-
ast_visitor.visit(func_ast)
623+
ast_visitor = _ClassDefaultAndAnnotationExtractor()
624+
ast_visitor.extract(cls)
617625

618-
annotation_source_dict.update(ast_visitor.annotations)
619-
defaults_source_dict.update(ast_visitor.defaults)
620-
621-
return annotation_source_dict, defaults_source_dict, return_annotation_source
626+
return (ast_visitor.annotations, ast_visitor.defaults,
627+
ast_visitor.return_annotation)
622628

623629

624-
def _extract_callable_defaults_and_annotations(
630+
def _extract_arg_defaults_and_annotations(
625631
func: Callable[..., Any]) -> AnnotsDefaultsReturns:
626632
"""Extract ast defaults and annotations form a standard callable."""
627633

628-
ast_visitor = _CallableDefaultAndAnnotationExtractor()
634+
ast_visitor = _ArgDefaultAndAnnotationExtractor()
629635

630636
annotation_source_dict = {}
631637
defaults_source_dict = {}
632638
return_annotation_source = EMPTY
633639

634640
try:
635-
func_source = textwrap.dedent(inspect.getsource(func))
636-
func_ast = ast.parse(func_source)
637641
# Extract the type annotation from the parsed ast.
638-
ast_visitor.visit(func_ast)
642+
ast_visitor.extract(func)
639643
except Exception: # pylint: disable=broad-except
640644
# A wide-variety of errors can be thrown here.
641645
pass

tools/tensorflow_docs/api_generator/signature_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,28 @@ class Child(Parent):
408408
)""")
409409
self.assertEqual(expected, str(sig))
410410

411+
def test_extract_non_annotated(self):
412+
413+
const = 1234
414+
415+
class A:
416+
a = 60 * 60
417+
b = 1 / 9
418+
419+
class B(A):
420+
b = 2 / 9
421+
c = const
422+
423+
ast_extractor = signature._ClassDefaultAndAnnotationExtractor()
424+
ast_extractor.extract(B)
425+
426+
self.assertEqual({
427+
'a': '(60 * 60)',
428+
'b': '(2 / 9)',
429+
'c': 'const'
430+
}, ast_extractor.defaults)
431+
432+
411433
def test_vararg_before_kwargonly_consistent_order(self):
412434

413435
def my_fun(*args, a=1, **kwargs): # pylint: disable=unused-argument

0 commit comments

Comments
 (0)