Skip to content

Commit 2c2ef3f

Browse files
MarkDaoust authored and copybara-github committed
Use clearer names for data-structures.
PiperOrigin-RevId: 427784997
1 parent 7f12198 commit 2c2ef3f

File tree

2 files changed

+81
-47
lines changed

2 files changed

+81
-47
lines changed

tools/tensorflow_docs/api_generator/doc_generator_visitor.py

Lines changed: 73 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,10 @@
1616
"""A `traverse` visitor for processing documentation."""
1717

1818
import collections
19+
import dataclasses
1920
import inspect
2021

21-
from typing import Any, Dict, List, Optional, Tuple
22+
from typing import Any, Dict, List, Optional, Mapping, Tuple
2223

2324
ApiPath = Tuple[str, ...]
2425

@@ -48,25 +49,32 @@ def maybe_singleton(py_object: Any) -> bool:
4849
return is_immutable_type or (isinstance(py_object, tuple) and py_object == ()) # pylint: disable=g-explicit-bool-comparison
4950

5051

51-
class ApiTreeNode(object):
52-
"""Represents a single API end-point.
52+
@dataclasses.dataclass
53+
class PathTreeNode(object):
54+
"""Represents a path to an object in the API, an object can have many paths.
5355
5456
Attributes:
5557
path: A tuple of strings containing the path to the object from the root
5658
like `('tf', 'losses', 'hinge')`
57-
obj: The python object.
58-
children: A dictionary from short name to `ApiTreeNode`, including the
59-
children nodes.
60-
parent: The parent node.
59+
py_object: The python object.
60+
children: A dictionary from short name to `PathTreeNode`, of this node's
61+
children.
62+
parent: This node's parent. This is a tree, there can only be one.
6163
short_name: The last path component
6264
full_name: All path components joined with "."
6365
"""
66+
path: ApiPath
67+
py_object: Any
68+
parent: Optional['PathTreeNode']
69+
children: Dict[str, 'PathTreeNode'] = dataclasses.field(default_factory=dict)
6470

65-
def __init__(self, path: ApiPath, obj: Any, parent: Optional['ApiTreeNode']):
66-
self.path = path
67-
self.py_object = obj
68-
self.children: Dict[str, 'ApiTreeNode'] = {}
69-
self.parent = parent
71+
def __hash__(self):
72+
return id(self)
73+
74+
def __repr__(self):
75+
return f'{type(self).__name__}({self.full_name})'
76+
77+
__str__ = __repr__
7078

7179
@property
7280
def short_name(self) -> str:
@@ -77,24 +85,42 @@ def full_name(self) -> str:
7785
return '.'.join(self.path)
7886

7987

80-
class ApiTree(object):
81-
"""Represents all api end-points as a tree.
88+
class PathTree(Mapping[ApiPath, PathTreeNode]):
89+
"""An index/tree of all object-paths in the API.
8290
8391
Items must be inserted in order, from root to leaf.
8492
93+
Acts as a Dict[ApiPath, PathTreeNode].
94+
8595
Attributes:
86-
index: A dict, mapping from path tuples to `ApiTreeNode`.
87-
aliases: A dict, mapping from object ids to a list of all `ApiTreeNode` that
88-
refer to the object.
89-
root: The root `ApiTreeNode`
96+
root: The root `PathTreeNode`
9097
"""
9198

9299
def __init__(self):
93-
root = ApiTreeNode(path=(), obj=None, parent=None)
94-
self.index: Dict[ApiPath, ApiTreeNode] = {(): root}
95-
self.aliases: Dict[ApiPath,
96-
List[ApiTreeNode]] = collections.defaultdict(list)
97-
self.root: ApiTreeNode = root
100+
root = PathTreeNode(path=(), py_object=None, parent=None)
101+
self._index: Dict[ApiPath, PathTreeNode] = {(): root}
102+
103+
self.root: PathTreeNode = root
104+
self._nodes_for_id: Dict[int, List[PathTreeNode]] = (
105+
collections.defaultdict(list))
106+
107+
def keys(self):
108+
"""Returns the paths currently contained in the tree."""
109+
return self._index.keys()
110+
111+
def __iter__(self):
112+
return iter(self._index)
113+
114+
def __len__(self):
115+
return len(self._index)
116+
117+
def values(self):
118+
"""Returns the path-nodes for each node currently in the tree."""
119+
return self._index.values()
120+
121+
def items(self):
122+
"""Returns the (path, node) pairs for each node currently in the tree."""
123+
return self._index.items()
98124

99125
def __contains__(self, path: ApiPath) -> bool:
100126
"""Returns `True` if path exists in the tree.
@@ -105,21 +131,24 @@ def __contains__(self, path: ApiPath) -> bool:
105131
Returns:
106132
True if `path` exists in the tree.
107133
"""
108-
return path in self.index
134+
return path in self._index
109135

110-
def __getitem__(self, path: ApiPath) -> ApiTreeNode:
136+
def __getitem__(self, path: ApiPath) -> PathTreeNode:
111137
"""Fetch an item from the tree.
112138
113139
Args:
114140
path: A tuple of strings, the api path to the object.
115141
116142
Returns:
117-
An `ApiTreeNode`.
143+
A `PathTreeNode`.
118144
119145
Raises:
120146
KeyError: If no node can be found at that path.
121147
"""
122-
return self.index[path]
148+
return self._index[path]
149+
150+
def get(self, path: ApiPath, default=None):
151+
return self._index.get(path, default)
123152

124153
def __setitem__(self, path: ApiPath, obj: Any):
125154
"""Add an object to the tree.
@@ -129,18 +158,21 @@ def __setitem__(self, path: ApiPath, obj: Any):
129158
obj: The python object.
130159
"""
131160
parent_path = path[:-1]
132-
parent = self.index[parent_path]
161+
parent = self._index[parent_path]
133162

134-
node = ApiTreeNode(path=path, obj=obj, parent=parent)
163+
node = PathTreeNode(path=path, py_object=obj, parent=parent)
135164

136-
self.index[path] = node
165+
self._index[path] = node
137166
if not maybe_singleton(obj):
138167
# We cannot use the duplicate mechanism for some constants, since e.g.,
139168
# id(c1) == id(c2) with c1=1, c2=1. This isn't problematic since constants
140169
# have no usable docstring and won't be documented automatically.
141-
self.aliases[id(obj)].append(node) # pytype: disable=unsupported-operands # attribute-variable-annotations
170+
self.nodes_for_obj(obj).append(node)
142171
parent.children[node.short_name] = node
143172

173+
def nodes_for_obj(self, py_object) -> List[PathTreeNode]:
174+
return self._nodes_for_id[id(py_object)]
175+
144176

145177
class DocGeneratorVisitor(object):
146178
"""A visitor that generates docs for a python object when __call__ed."""
@@ -174,7 +206,7 @@ def __init__(self):
174206
self._duplicates: Dict[str, List[str]] = None
175207
self._duplicate_of: Dict[str, str] = None
176208

177-
self._api_tree = ApiTree()
209+
self._path_tree = PathTree()
178210

179211
@property
180212
def index(self):
@@ -270,15 +302,16 @@ class or module.
270302
parent_name = '.'.join(parent_path)
271303
self._index[parent_name] = parent
272304
self._tree[parent_name] = []
273-
if parent_path not in self._api_tree:
274-
self._api_tree[parent_path] = parent
305+
if parent_path not in self._path_tree:
306+
self._path_tree[parent_path] = parent
275307

276308
if not (inspect.ismodule(parent) or inspect.isclass(parent)):
277-
raise RuntimeError('Unexpected type in visitor -- '
278-
f'{parent_name}: {parent!r}')
309+
raise TypeError('Unexpected type in visitor -- '
310+
f'{parent_name}: {parent!r}')
279311

280312
for name, child in children:
281-
self._api_tree[parent_path + (name,)] = child
313+
child_path = parent_path + (name,)
314+
self._path_tree[child_path] = child
282315

283316
full_name = '.'.join([parent_name, name]) if parent_name else name
284317
self._index[full_name] = child
@@ -379,7 +412,7 @@ def _maybe_find_duplicates(self):
379412
# symbol (incl. itself).
380413
duplicates = {}
381414

382-
for path, node in self._api_tree.index.items():
415+
for path, node in self._path_tree.items():
383416
if not path:
384417
continue
385418
full_name = node.full_name
@@ -388,7 +421,8 @@ def _maybe_find_duplicates(self):
388421
if full_name in duplicates:
389422
continue
390423

391-
aliases = self._api_tree.aliases[object_id]
424+
aliases = self._path_tree.nodes_for_obj(py_object)
425+
# maybe_singleton types can't be looked up by object.
392426
if not aliases:
393427
aliases = [node]
394428

tools/tensorflow_docs/api_generator/doc_generator_visitor_test.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ def test_call_class(self):
6868

6969
def test_call_raises(self):
7070
visitor = doc_generator_visitor.DocGeneratorVisitor()
71-
with self.assertRaises(RuntimeError):
71+
with self.assertRaises(TypeError):
7272
visitor(('non_class_or_module',), 'non_class_or_module_object', [])
7373

7474
def test_duplicates_module_class_depth(self):
@@ -255,13 +255,13 @@ class Parent(object):
255255
}, visitor.reverse_index)
256256

257257

258-
class ApiTreeTest(absltest.TestCase):
258+
class PathTreeTest(absltest.TestCase):
259259

260260
def test_contains(self):
261261
tf = argparse.Namespace()
262262
tf.sub = argparse.Namespace()
263263

264-
tree = doc_generator_visitor.ApiTree()
264+
tree = doc_generator_visitor.PathTree()
265265
tree[('tf',)] = tf
266266
tree[('tf', 'sub')] = tf.sub
267267

@@ -273,7 +273,7 @@ def test_node_insertion(self):
273273
tf.sub = argparse.Namespace()
274274
tf.sub.object = object()
275275

276-
tree = doc_generator_visitor.ApiTree()
276+
tree = doc_generator_visitor.PathTree()
277277
tree[('tf',)] = tf
278278
tree[('tf', 'sub')] = tf.sub
279279
tree[('tf', 'sub', 'thing')] = tf.sub.object
@@ -292,15 +292,15 @@ def test_duplicate(self):
292292
tf.sub2 = argparse.Namespace()
293293
tf.sub2.thing = tf.sub.thing
294294

295-
tree = doc_generator_visitor.ApiTree()
295+
tree = doc_generator_visitor.PathTree()
296296
tree[('tf',)] = tf
297297
tree[('tf', 'sub')] = tf.sub
298298
tree[('tf', 'sub', 'thing')] = tf.sub.thing
299299
tree[('tf', 'sub2')] = tf.sub2
300300
tree[('tf', 'sub2', 'thing')] = tf.sub2.thing
301301

302302
self.assertCountEqual(
303-
tree.aliases[id(tf.sub.thing)],
303+
tree.nodes_for_obj(tf.sub.thing),
304304
[tree[('tf', 'sub', 'thing')], tree[('tf', 'sub2', 'thing')]])
305305

306306
def test_duplicate_singleton(self):
@@ -310,14 +310,14 @@ def test_duplicate_singleton(self):
310310
tf.sub2 = argparse.Namespace()
311311
tf.sub2.thing = tf.sub.thing
312312

313-
tree = doc_generator_visitor.ApiTree()
313+
tree = doc_generator_visitor.PathTree()
314314
tree[('tf',)] = tf
315315
tree[('tf', 'sub')] = tf.sub
316316
tree[('tf', 'sub', 'thing')] = tf.sub.thing
317317
tree[('tf', 'sub2')] = tf.sub2
318318
tree[('tf', 'sub2', 'thing')] = tf.sub2.thing
319319

320-
self.assertEmpty(tree.aliases[tf.sub.thing], [])
320+
self.assertEmpty(tree.nodes_for_obj(tf.sub.thing), [])
321321

322322

323323
if __name__ == '__main__':

0 commit comments

Comments
 (0)