Skip to content

Commit 65bd57a

Browse files
MarkDaoustcopybara-github
authored andcommitted
Simplify test object construction.
Manually building the indices is complicated and error prone. Never do it manually. PiperOrigin-RevId: 438277327
1 parent 75cf485 commit 65bd57a

File tree

6 files changed

+290
-530
lines changed

6 files changed

+290
-530
lines changed

tools/tensorflow_docs/api_generator/doc_generator_visitor.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def __init__(self):
205205
self._duplicates: Dict[str, List[str]] = None
206206
self._duplicate_of: Dict[str, str] = None
207207

208-
self._path_tree = PathTree()
208+
self.path_tree = PathTree()
209209

210210
@property
211211
def index(self):
@@ -242,7 +242,6 @@ def reverse_index(self):
242242
Returns:
243243
The `id(object)` to full name map.
244244
"""
245-
self._maybe_find_duplicates()
246245
return self._reverse_index
247246

248247
@property
@@ -256,7 +255,6 @@ def duplicate_of(self):
256255
Returns:
257256
The map from duplicate name to preferred name.
258257
"""
259-
self._maybe_find_duplicates()
260258
return self._duplicate_of
261259

262260
@property
@@ -273,7 +271,6 @@ def duplicates(self):
273271
Returns:
274272
The map from main name to list of all duplicate names.
275273
"""
276-
self._maybe_find_duplicates()
277274
return self._duplicates
278275

279276
def __call__(self, parent_path, parent, children):
@@ -301,16 +298,16 @@ class or module.
301298
parent_name = '.'.join(parent_path)
302299
self._index[parent_name] = parent
303300
self._tree[parent_name] = []
304-
if parent_path not in self._path_tree:
305-
self._path_tree[parent_path] = parent
301+
if parent_path not in self.path_tree:
302+
self.path_tree[parent_path] = parent
306303

307304
if not (inspect.ismodule(parent) or inspect.isclass(parent)):
308305
raise TypeError('Unexpected type in visitor -- '
309306
f'{parent_name}: {parent!r}')
310307

311308
for name, child in children:
312309
child_path = parent_path + (name,)
313-
self._path_tree[child_path] = child
310+
self.path_tree[child_path] = child
314311

315312
full_name = '.'.join([parent_name, name]) if parent_name else name
316313
self._index[full_name] = child
@@ -379,7 +376,7 @@ def _score_name(self, name):
379376
return (defining_class_score, experimental_score, keras_score,
380377
module_length_score, name)
381378

382-
def _maybe_find_duplicates(self):
379+
def build(self):
383380
"""Compute data structures containing information about duplicates.
384381
385382
Find duplicates in `index` and decide on one to be the "main" name.
@@ -411,7 +408,7 @@ def _maybe_find_duplicates(self):
411408
# symbol (incl. itself).
412409
duplicates = {}
413410

414-
for path, node in self._path_tree.items():
411+
for path, node in self.path_tree.items():
415412
if not path:
416413
continue
417414
full_name = node.full_name
@@ -420,7 +417,7 @@ def _maybe_find_duplicates(self):
420417
if full_name in duplicates:
421418
continue
422419

423-
aliases = self._path_tree.nodes_for_obj(py_object)
420+
aliases = self.path_tree.nodes_for_obj(py_object)
424421
# maybe_singleton types can't be looked up by object.
425422
if not aliases:
426423
aliases = [node]

tools/tensorflow_docs/api_generator/generate_lib.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,6 @@
3939

4040
import yaml
4141

42-
try:
43-
# TODO(markdaoust) delete this when the warning is in a stable release.
44-
_estimator = importlib.import_module(
45-
'tensorflow_estimator.python.estimator.estimator')
46-
47-
if doc_controls.get_inheritable_header(_estimator.Estimator) is None:
48-
_add_header = doc_controls.inheritable_header("""\
49-
Warning: Estimators are not recommended for new code. Estimators run
50-
`v1.Session`-style code which is more difficult to write correctly, and
51-
can behave unexpectedly, especially when combined with TF 2 code.
52-
Estimators do fall under our
53-
[compatibility guarantees](https://tensorflow.org/guide/versions), but
54-
will receive no fixes other than security vulnerabilities. See the
55-
[migration guide](https://tensorflow.org/guide/migrate) for details.
56-
""")
57-
_add_header(_estimator.Estimator)
58-
except ImportError:
59-
pass
60-
61-
6242
# Used to add a collections.OrderedDict representer to yaml so that the
6343
# dump doesn't contain !!OrderedDict yaml tags.
6444
# Reference: https://stackoverflow.com/a/21048064
@@ -687,6 +667,7 @@ def extract(py_modules,
687667

688668
traverse.traverse(py_module, visitors, short_name)
689669

670+
accumulator.build()
690671
return accumulator
691672

692673

@@ -818,12 +799,17 @@ def run_extraction(self):
818799
819800
Returns:
820801
"""
821-
return extract(
802+
visitor = extract(
822803
py_modules=self._py_modules,
823804
base_dir=self._base_dir,
824805
private_map=self._private_map,
825806
visitor_cls=self._visitor_cls,
826807
callbacks=self._callbacks)
808+
reference_resolver = self.make_reference_resolver(visitor)
809+
810+
# Write the api docs.
811+
parser_config = self.make_parser_config(visitor, reference_resolver)
812+
return parser_config
827813

828814
def build(self, output_dir):
829815
"""Build all the docs.
@@ -838,11 +824,7 @@ def build(self, output_dir):
838824
workdir = pathlib.Path(tempfile.mkdtemp())
839825

840826
# Extract the python api from the _py_modules
841-
visitor = self.run_extraction()
842-
reference_resolver = self.make_reference_resolver(visitor)
843-
844-
# Write the api docs.
845-
parser_config = self.make_parser_config(visitor, reference_resolver)
827+
parser_config = self.run_extraction()
846828
work_py_dir = workdir / 'api_docs/python'
847829
write_docs(
848830
output_dir=str(work_py_dir),
@@ -859,7 +841,7 @@ def build(self, output_dir):
859841
)
860842

861843
if self.api_cache:
862-
reference_resolver.to_json_file(
844+
parser_config.reference_resolver.to_json_file(
863845
str(work_py_dir / self._short_name.replace('.', '/') /
864846
'_api_cache.json'))
865847

tools/tensorflow_docs/api_generator/generate_lib_test.py

Lines changed: 16 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import sys
2020
import tempfile
2121
import textwrap
22+
import types
2223

2324
from absl import flags
2425
from absl.testing import absltest
@@ -77,59 +78,21 @@ def setUp(self):
7778
def get_test_objects(self):
7879
# These are all mutable objects, so rebuild them for each test.
7980
# Don't cache the objects.
80-
module = sys.modules[__name__]
81-
82-
index = {
83-
'tf':
84-
sys, # Can be any module, this test doesn't care about content.
85-
'tf.TestModule':
86-
module,
87-
'tf.test_function':
88-
test_function,
89-
'tf.TestModule.test_function':
90-
test_function,
91-
'tf.TestModule.TestClass':
92-
TestClass,
93-
'tf.TestModule.TestClass.ChildClass':
94-
TestClass.ChildClass,
95-
'tf.TestModule.TestClass.ChildClass.GrandChildClass':
96-
TestClass.ChildClass.GrandChildClass,
97-
}
98-
99-
tree = {
100-
'tf': ['TestModule', 'test_function'],
101-
'tf.TestModule': ['test_function', 'TestClass'],
102-
'tf.TestModule.TestClass': ['ChildClass'],
103-
'tf.TestModule.TestClass.ChildClass': ['GrandChildClass'],
104-
'tf.TestModule.TestClass.ChildClass.GrandChildClass': []
105-
}
106-
107-
duplicate_of = {'tf.test_function': 'tf.TestModule.test_function'}
108-
109-
duplicates = {
110-
'tf.TestModule.test_function': [
111-
'tf.test_function', 'tf.TestModule.test_function'
112-
]
113-
}
114-
115-
base_dir = os.path.dirname(__file__)
116-
117-
visitor = DummyVisitor(index, duplicate_of)
118-
119-
reference_resolver = reference_resolver_lib.ReferenceResolver.from_visitor(
120-
visitor=visitor, py_module_names=['tf'], link_prefix='api_docs/python')
121-
122-
parser_config = config.ParserConfig(
123-
reference_resolver=reference_resolver,
124-
duplicates=duplicates,
125-
duplicate_of=duplicate_of,
126-
tree=tree,
127-
index=index,
128-
reverse_index={},
129-
base_dir=base_dir,
130-
code_url_prefix='/')
131-
132-
return reference_resolver, parser_config
81+
tf = types.ModuleType('tf')
82+
tf.__file__ = __file__
83+
tf.TestModule = types.ModuleType('module')
84+
tf.test_function = test_function
85+
tf.TestModule.test_function = test_function
86+
tf.TestModule.TestClass = TestClass
87+
88+
generator = generate_lib.DocGenerator(
89+
root_title='TensorFlow',
90+
py_modules=[('tf', tf)],
91+
code_url_prefix='https://tensorflow.org/')
92+
93+
parser_config = generator.run_extraction()
94+
95+
return parser_config.reference_resolver, parser_config
13396

13497
def test_write(self):
13598
_, parser_config = self.get_test_objects()

0 commit comments

Comments
 (0)