Skip to content

Commit f5619f5

Browse files
Merge pull request #2199 from MarkDaoust:proto
PiperOrigin-RevId: 517013282
2 parents e7f81c2 + 03b59e4 commit f5619f5

File tree

3 files changed

+202
-171
lines changed

3 files changed

+202
-171
lines changed

tools/tensorflow_docs/api_generator/doc_generator_visitor_test.py

Lines changed: 144 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# ==============================================================================
1515
"""Tests for tools.docs.doc_generator_visitor."""
1616

17+
import dataclasses
1718
import io
1819
import os
1920
import textwrap
@@ -36,6 +37,17 @@ def __call__(self, parent_name, parent, children):
3637
return super(NoDunderVisitor, self).__call__(parent_name, parent, children)
3738

3839

40+
class TestDocGenerator(generate_lib.DocGenerator):
41+
42+
def __init__(self, py_modules):
43+
kwargs = {}
44+
kwargs['py_modules'] = py_modules
45+
kwargs['root_title'] = 'TensorFlow'
46+
kwargs['visitor_cls'] = NoDunderVisitor
47+
kwargs['code_url_prefix'] = '/'
48+
super().__init__(**kwargs)
49+
50+
3951
class DocGeneratorVisitorTest(absltest.TestCase):
4052

4153
def test_call_module(self):
@@ -91,39 +103,41 @@ class Nested(object):
91103
tf.submodule = types.ModuleType('submodule')
92104
tf.submodule.Parent = Parent
93105

94-
visitor = generate_lib.extract(
95-
[('tf', tf)],
96-
base_dir=os.path.dirname(tf.__file__),
97-
private_map={},
98-
visitor_cls=NoDunderVisitor)
106+
config = TestDocGenerator([('tf', tf)]).run_extraction()
99107

100108
self.assertEqual(
101109
{
102-
'tf.submodule.Parent':
103-
sorted([
104-
'tf.Parent',
105-
'tf.submodule.Parent',
106-
]),
107-
'tf.submodule.Parent.Nested':
108-
sorted([
109-
'tf.Parent.Nested',
110-
'tf.submodule.Parent.Nested',
111-
]),
110+
'tf.submodule.Parent': sorted([
111+
'tf.Parent',
112+
'tf.submodule.Parent',
113+
]),
114+
'tf.submodule.Parent.Nested': sorted([
115+
'tf.Parent.Nested',
116+
'tf.submodule.Parent.Nested',
117+
]),
112118
'tf': ['tf'],
113-
'tf.submodule': ['tf.submodule']
114-
}, visitor.duplicates)
119+
'tf.submodule': ['tf.submodule'],
120+
},
121+
config.duplicates,
122+
)
115123

116-
self.assertEqual({
117-
'tf.Parent.Nested': 'tf.submodule.Parent.Nested',
118-
'tf.Parent': 'tf.submodule.Parent',
119-
}, visitor.duplicate_of)
124+
self.assertEqual(
125+
{
126+
'tf.Parent.Nested': 'tf.submodule.Parent.Nested',
127+
'tf.Parent': 'tf.submodule.Parent',
128+
},
129+
config.duplicate_of,
130+
)
120131

121-
self.assertEqual({
122-
id(Parent): 'tf.submodule.Parent',
123-
id(Parent.Nested): 'tf.submodule.Parent.Nested',
124-
id(tf): 'tf',
125-
id(tf.submodule): 'tf.submodule',
126-
}, visitor.reverse_index)
132+
self.assertEqual(
133+
{
134+
id(Parent): 'tf.submodule.Parent',
135+
id(Parent.Nested): 'tf.submodule.Parent.Nested',
136+
id(tf): 'tf',
137+
id(tf.submodule): 'tf.submodule',
138+
},
139+
config.reverse_index,
140+
)
127141

128142
def test_duplicates_contrib(self):
129143

@@ -137,25 +151,29 @@ class Parent(object):
137151
tf.contrib.Parent = Parent
138152
tf.submodule.Parent = Parent
139153

140-
visitor = generate_lib.extract(
141-
[('tf', tf)],
142-
base_dir=os.path.dirname(tf.__file__),
143-
private_map={},
144-
visitor_cls=NoDunderVisitor)
154+
config = TestDocGenerator([('tf', tf)]).run_extraction()
145155

146-
self.assertCountEqual(['tf.contrib.Parent', 'tf.submodule.Parent'],
147-
visitor.duplicates['tf.submodule.Parent'])
156+
self.assertCountEqual(
157+
['tf.contrib.Parent', 'tf.submodule.Parent'],
158+
config.duplicates['tf.submodule.Parent'],
159+
)
148160

149-
self.assertEqual({
150-
'tf.contrib.Parent': 'tf.submodule.Parent',
151-
}, visitor.duplicate_of)
161+
self.assertEqual(
162+
{
163+
'tf.contrib.Parent': 'tf.submodule.Parent',
164+
},
165+
config.duplicate_of,
166+
)
152167

153-
self.assertEqual({
154-
id(tf): 'tf',
155-
id(tf.submodule): 'tf.submodule',
156-
id(Parent): 'tf.submodule.Parent',
157-
id(tf.contrib): 'tf.contrib',
158-
}, visitor.reverse_index)
168+
self.assertEqual(
169+
{
170+
id(tf): 'tf',
171+
id(tf.submodule): 'tf.submodule',
172+
id(Parent): 'tf.submodule.Parent',
173+
id(tf.contrib): 'tf.contrib',
174+
},
175+
config.reverse_index,
176+
)
159177

160178
def test_duplicates_defining_class(self):
161179

@@ -170,25 +188,28 @@ class Child(Parent):
170188
tf.Parent = Parent
171189
tf.Child = Child
172190

173-
visitor = generate_lib.extract(
174-
[('tf', tf)],
175-
base_dir=os.path.dirname(tf.__file__),
176-
private_map={},
177-
visitor_cls=NoDunderVisitor)
191+
config = TestDocGenerator([('tf', tf)]).run_extraction()
178192

179-
self.assertCountEqual(['tf.Parent.obj1', 'tf.Child.obj1'],
180-
visitor.duplicates['tf.Parent.obj1'])
193+
self.assertCountEqual(
194+
['tf.Parent.obj1', 'tf.Child.obj1'], config.duplicates['tf.Parent.obj1']
195+
)
181196

182-
self.assertEqual({
183-
'tf.Child.obj1': 'tf.Parent.obj1',
184-
}, visitor.duplicate_of)
197+
self.assertEqual(
198+
{
199+
'tf.Child.obj1': 'tf.Parent.obj1',
200+
},
201+
config.duplicate_of,
202+
)
185203

186-
self.assertEqual({
187-
id(tf): 'tf',
188-
id(Parent): 'tf.Parent',
189-
id(Child): 'tf.Child',
190-
id(Parent.obj1): 'tf.Parent.obj1',
191-
}, visitor.reverse_index)
204+
self.assertEqual(
205+
{
206+
id(tf): 'tf',
207+
id(Parent): 'tf.Parent',
208+
id(Child): 'tf.Child',
209+
id(Parent.obj1): 'tf.Parent.obj1',
210+
},
211+
config.reverse_index,
212+
)
192213

193214
def test_duplicates_module_depth(self):
194215

@@ -202,25 +223,26 @@ class Parent(object):
202223
tf.Parent = Parent
203224
tf.submodule.submodule2.Parent = Parent
204225

205-
visitor = generate_lib.extract(
206-
[('tf', tf)],
207-
base_dir=os.path.dirname(tf.__file__),
208-
private_map={},
209-
visitor_cls=NoDunderVisitor)
226+
config = TestDocGenerator([('tf', tf)]).run_extraction()
210227

211-
self.assertCountEqual(['tf.Parent', 'tf.submodule.submodule2.Parent'],
212-
visitor.duplicates['tf.Parent'])
228+
self.assertCountEqual(
229+
['tf.Parent', 'tf.submodule.submodule2.Parent'],
230+
config.duplicates['tf.Parent'],
231+
)
213232

214-
self.assertEqual({
215-
'tf.submodule.submodule2.Parent': 'tf.Parent'
216-
}, visitor.duplicate_of)
233+
self.assertEqual(
234+
{'tf.submodule.submodule2.Parent': 'tf.Parent'}, config.duplicate_of
235+
)
217236

218-
self.assertEqual({
219-
id(tf): 'tf',
220-
id(tf.submodule): 'tf.submodule',
221-
id(tf.submodule.submodule2): 'tf.submodule.submodule2',
222-
id(Parent): 'tf.Parent',
223-
}, visitor.reverse_index)
237+
self.assertEqual(
238+
{
239+
id(tf): 'tf',
240+
id(tf.submodule): 'tf.submodule',
241+
id(tf.submodule.submodule2): 'tf.submodule.submodule2',
242+
id(Parent): 'tf.Parent',
243+
},
244+
config.reverse_index,
245+
)
224246

225247
def test_duplicates_name(self):
226248

@@ -234,27 +256,32 @@ class Parent(object):
234256
tf.submodule = types.ModuleType('submodule')
235257
tf.submodule.Parent = Parent
236258

237-
visitor = generate_lib.extract(
238-
[('tf', tf)],
239-
base_dir=os.path.dirname(tf.__file__),
240-
private_map={},
241-
visitor_cls=NoDunderVisitor)
259+
config = TestDocGenerator([('tf', tf)]).run_extraction()
260+
242261
self.assertEqual(
243262
sorted([
244263
'tf.submodule.Parent.obj1',
245264
'tf.submodule.Parent.obj2',
246-
]), visitor.duplicates['tf.submodule.Parent.obj1'])
265+
]),
266+
config.duplicates['tf.submodule.Parent.obj1'],
267+
)
247268

248-
self.assertEqual({
249-
'tf.submodule.Parent.obj2': 'tf.submodule.Parent.obj1',
250-
}, visitor.duplicate_of)
269+
self.assertEqual(
270+
{
271+
'tf.submodule.Parent.obj2': 'tf.submodule.Parent.obj1',
272+
},
273+
config.duplicate_of,
274+
)
251275

252-
self.assertEqual({
253-
id(tf): 'tf',
254-
id(tf.submodule): 'tf.submodule',
255-
id(Parent): 'tf.submodule.Parent',
256-
id(Parent.obj1): 'tf.submodule.Parent.obj1',
257-
}, visitor.reverse_index)
276+
self.assertEqual(
277+
{
278+
id(tf): 'tf',
279+
id(tf.submodule): 'tf.submodule',
280+
id(Parent): 'tf.submodule.Parent',
281+
id(Parent.obj1): 'tf.submodule.Parent.obj1',
282+
},
283+
config.reverse_index,
284+
)
258285

259286
def test_handles_duplicate_classmethods(self):
260287

@@ -270,12 +297,9 @@ def from_value(cls, value):
270297
tf.sub = types.ModuleType('sub')
271298
tf.sub.MyClass = MyClass
272299

273-
visitor = generate_lib.extract([('tf', tf)],
274-
base_dir=os.path.dirname(tf.__file__),
275-
private_map={},
276-
visitor_cls=NoDunderVisitor)
300+
config = TestDocGenerator([('tf', tf)]).run_extraction()
277301

278-
paths = ['.'.join(p) for p in visitor.path_tree.keys()]
302+
paths = ['.'.join(p) for p in config.path_tree.keys()]
279303

280304
expected = [
281305
'',
@@ -288,7 +312,7 @@ def from_value(cls, value):
288312
]
289313
self.assertCountEqual(expected, paths)
290314

291-
apis = [node.full_name for node in visitor.api_tree.iter_nodes()]
315+
apis = [node.full_name for node in config.api_tree.iter_nodes()]
292316
expected = [
293317
'tf',
294318
'tf.sub',
@@ -297,10 +321,14 @@ def from_value(cls, value):
297321
]
298322
self.assertCountEqual(expected, apis)
299323

300-
self.assertIs(visitor.api_tree[('tf', 'MyClass')],
301-
visitor.api_tree[('tf', 'sub', 'MyClass')])
302-
self.assertIs(visitor.api_tree[('tf', 'MyClass', 'from_value')],
303-
visitor.api_tree[('tf', 'sub', 'MyClass', 'from_value')])
324+
self.assertIs(
325+
config.api_tree[('tf', 'MyClass')],
326+
config.api_tree[('tf', 'sub', 'MyClass')],
327+
)
328+
self.assertIs(
329+
config.api_tree[('tf', 'MyClass', 'from_value')],
330+
config.api_tree[('tf', 'sub', 'MyClass', 'from_value')],
331+
)
304332

305333
def test_handles_duplicate_singleton_attributes(self):
306334

@@ -313,12 +341,9 @@ class MyClass:
313341
tf.sub = types.ModuleType('sub')
314342
tf.sub.MyClass = MyClass
315343

316-
visitor = generate_lib.extract([('tf', tf)],
317-
base_dir=os.path.dirname(tf.__file__),
318-
private_map={},
319-
visitor_cls=NoDunderVisitor)
344+
config = TestDocGenerator([('tf', tf)]).run_extraction()
320345

321-
paths = ['.'.join(p) for p in visitor.path_tree.keys()]
346+
paths = ['.'.join(p) for p in config.path_tree.keys()]
322347

323348
expected = [
324349
'',
@@ -331,7 +356,7 @@ class MyClass:
331356
]
332357
self.assertCountEqual(expected, paths)
333358

334-
apis = ['.'.join(p) for p in visitor.api_tree.keys()]
359+
apis = ['.'.join(p) for p in config.api_tree.keys()]
335360
expected = [
336361
'',
337362
'tf',
@@ -343,10 +368,14 @@ class MyClass:
343368
]
344369
self.assertCountEqual(expected, apis)
345370

346-
self.assertIs(visitor.api_tree[('tf', 'MyClass')],
347-
visitor.api_tree[('tf', 'sub', 'MyClass')])
348-
self.assertIs(visitor.api_tree[('tf', 'MyClass', 'simple')],
349-
visitor.api_tree[('tf', 'sub', 'MyClass', 'simple')])
371+
self.assertIs(
372+
config.api_tree[('tf', 'MyClass')],
373+
config.api_tree[('tf', 'sub', 'MyClass')],
374+
)
375+
self.assertIs(
376+
config.api_tree[('tf', 'MyClass', 'simple')],
377+
config.api_tree[('tf', 'sub', 'MyClass', 'simple')],
378+
)
350379

351380

352381
class PathTreeTest(absltest.TestCase):
@@ -540,10 +569,11 @@ def test_from_path_tree(self):
540569
def test_api_tree_toc_integration(self):
541570
tf = self._make_fake_module()
542571

543-
visitor = generate_lib.extract([('tf', tf)],
544-
base_dir=os.path.dirname(tf.__file__),
545-
private_map={},
546-
visitor_cls=NoDunderVisitor)
572+
gen = TestDocGenerator([('tf', tf)])
573+
filters = gen.make_default_filters()
574+
visitor = generate_lib.extract(
575+
[('tf', tf)], filters=filters, visitor_cls=NoDunderVisitor
576+
)
547577

548578
api_tree = doc_generator_visitor.ApiTree.from_path_tree(
549579
visitor.path_tree, visitor._score_name)

0 commit comments

Comments
 (0)