Skip to content

Commit 3a6a7a3

Browse files
MarkDaoustcopybara-github
authored andcommitted
normalize the public_api filters
More testabole -> more tests. PiperOrigin-RevId: 440850613
1 parent 44a3c7b commit 3a6a7a3

File tree

7 files changed

+265
-308
lines changed

7 files changed

+265
-308
lines changed

tools/tensorflow_docs/api_generator/generate_lib.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ def extract(py_modules,
215215
base_dir,
216216
private_map,
217217
visitor_cls=doc_generator_visitor.DocGeneratorVisitor,
218-
callbacks=None):
218+
callbacks=None,
219+
include_default_callbacks=True):
219220
"""Walks the module contents, returns an index of all visited objects.
220221
221222
The return value is an instance of `self._visitor_cls`, usually:
@@ -235,6 +236,9 @@ def extract(py_modules,
235236
`PublicApiFilter` and the accumulator (`DocGeneratorVisitor`). The
236237
primary use case for these is to filter the list of children (see:
237238
`public_api.local_definitions_filter`)
239+
include_default_callbacks: When true the long list of standard
240+
visitor-callbacks are included. When false, only the `callbacks` argument
241+
is used.
238242
239243
Returns:
240244
The accumulator (`DocGeneratorVisitor`)
@@ -246,16 +250,28 @@ def extract(py_modules,
246250
raise ValueError("only pass one [('name',module)] pair in py_modules")
247251
short_name, py_module = py_modules[0]
248252

249-
api_filter = public_api.PublicAPIFilter(
250-
base_dir=base_dir,
251-
private_map=private_map)
252-
253-
accumulator = visitor_cls()
254253

255254
# The objects found during traversal, and their children are passed to each
256255
# of these visitors in sequence. Each visitor returns the list of children
257256
# to be passed to the next visitor.
258-
visitors = [api_filter, public_api.ignore_typing] + callbacks + [accumulator]
257+
if include_default_callbacks:
258+
visitors = [
259+
# filter the api.
260+
public_api.FailIfNestedTooDeep(10),
261+
public_api.filter_module_all,
262+
public_api.add_proto_fields,
263+
public_api.filter_builtin_modules,
264+
public_api.filter_private_symbols,
265+
public_api.FilterBaseDirs(base_dir),
266+
public_api.FilterPrivateMap(private_map),
267+
public_api.filter_doc_controls_skip,
268+
public_api.ignore_typing
269+
]
270+
else:
271+
visitors = []
272+
273+
accumulator = visitor_cls()
274+
visitors = visitors + callbacks + [accumulator]
259275

260276
traverse.traverse(py_module, visitors, short_name)
261277

tools/tensorflow_docs/api_generator/public_api.py

Lines changed: 179 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,29 @@
1515
"""Visitor restricting traversal to only the public tensorflow API."""
1616

1717
import ast
18+
import dataclasses
1819
import inspect
1920
import os
21+
import sys
2022
import pathlib
2123
import textwrap
2224
import types
2325
import typing
24-
from typing import Any, Callable, List, Sequence, Tuple, Union
25-
26+
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Union
2627

2728
from tensorflow_docs.api_generator import doc_controls
2829
from tensorflow_docs.api_generator import doc_generator_visitor
2930
from tensorflow_docs.api_generator import get_source
3031

32+
from google.protobuf.message import Message as ProtoMessage
33+
3134
_TYPING_IDS = frozenset(
3235
id(obj)
3336
for obj in typing.__dict__.values()
3437
if not doc_generator_visitor.maybe_singleton(obj))
3538

3639

37-
Children = List[Tuple[str, Any]]
40+
Children = Iterable[Tuple[str, Any]]
3841
ApiFilter = Callable[[Tuple[str, ...], Any, Children], Children]
3942

4043

@@ -300,67 +303,195 @@ def explicit_package_contents_filter(path: Sequence[str], parent: Any,
300303
])
301304

302305

303-
class PublicAPIFilter(object):
304-
"""Visitor to use with `traverse` to filter just the public API."""
305-
306-
def __init__(self, base_dir, private_map=None):
307-
"""Constructor.
308-
309-
Args:
310-
base_dir: The directory to take source file paths relative to.
311-
private_map: A mapping from dotted path like "tf.symbol" to a list of
312-
names. Included names will not be listed at that location.
313-
"""
314-
self._base_dir = base_dir
315-
self._private_map = private_map or {}
316-
317-
def _is_private(self, path, parent, name, obj):
318-
"""Returns whether a name is private or not."""
306+
@dataclasses.dataclass
307+
class FailIfNestedTooDeep:
308+
max_depth: int
319309

320-
# Skip objects blocked by doc_controls.
321-
if doc_controls.should_skip(obj):
322-
return True
310+
def __call__(self, path: Sequence[str], parent: Any,
311+
children: Children) -> Children:
312+
if inspect.ismodule(parent) and len(path) > 10:
313+
raise RuntimeError('Modules nested too deep:\n\n{}\n\nThis is likely a '
314+
'problem with an accidental public import.'.format(
315+
'.'.join(path)))
316+
return children
323317

324-
if isinstance(parent, type):
325-
if doc_controls.should_skip_class_attr(parent, name):
326-
return True
327318

328-
if doc_controls.should_doc_private(obj):
329-
return False
319+
@dataclasses.dataclass
320+
class FilterBaseDirs:
321+
base_dirs: Sequence[pathlib.Path]
330322

331-
if inspect.ismodule(obj):
332-
mod_base_dirs = get_module_base_dirs(obj)
323+
def __call__(self, path: Sequence[str], parent: Any,
324+
children: Children) -> Children:
325+
for name, child in children:
326+
if not inspect.ismodule(child):
327+
yield name, child
328+
continue
329+
mod_base_dirs = get_module_base_dirs(child)
333330
# This check only handles normal packages/modules. Namespace-package
334331
# contents will get filtered when the submodules are checked.
335332
if len(mod_base_dirs) == 1:
336333
mod_base_dir = mod_base_dirs[0]
337334
# Check that module is in one of the `self._base_dir`s
338-
if not any(base in mod_base_dir.parents for base in self._base_dir):
339-
return True
335+
if not any(base in mod_base_dir.parents for base in self.base_dirs):
336+
continue
337+
yield name, child
338+
340339

341-
# Skip objects blocked by the private_map
342-
if name in self._private_map.get('.'.join(path), []):
343-
return True
340+
@dataclasses.dataclass
341+
class FilterPrivateMap:
342+
private_map: Dict[str, List[str]]
344343

344+
def __call__(self, path: Sequence[str], parent: Any,
345+
children: Children) -> Children:
346+
if self.private_map is None:
347+
yield from children
348+
349+
for name, child in children:
350+
if name in self.private_map.get('.'.join(path), []):
351+
continue
352+
yield (name, child)
353+
354+
355+
def filter_private_symbols(path: Sequence[str], parent: Any,
356+
children: Children) -> Children:
357+
del path
358+
del parent
359+
for name, child in children:
345360
# Skip "_" hidden attributes
346361
if name.startswith('_') and name not in ALLOWED_DUNDER_METHODS:
347-
return True
362+
if not doc_controls.should_doc_private(child):
363+
continue
364+
yield (name, child)
348365

349-
return False
350366

351-
def __call__(self, path: Sequence[str], parent: Any,
352-
children: Children) -> Children:
353-
"""Visitor interface, see `traverse` for details."""
367+
def filter_doc_controls_skip(path: Sequence[str], parent: Any,
368+
children: Children) -> Children:
369+
del path
370+
for name, child in children:
371+
if doc_controls.should_skip(child):
372+
continue
373+
if isinstance(parent, type):
374+
if doc_controls.should_skip_class_attr(parent, name):
375+
continue
376+
yield (name, child)
354377

355-
# Avoid long waits in cases of pretty unambiguous failure.
356-
if inspect.ismodule(parent) and len(path) > 10:
357-
raise RuntimeError('Modules nested too deep:\n\n{}\n\nThis is likely a '
358-
'problem with an accidental public import.'.format(
359-
'.'.join(path)))
360378

361-
# Remove things that are not visible.
362-
children = [(child_name, child_obj)
363-
for child_name, child_obj in list(children)
364-
if not self._is_private(path, parent, child_name, child_obj)]
379+
def filter_module_all(path: Sequence[str], parent: Any,
380+
children: Children) -> Children:
381+
"""Filters module children based on the "__all__" arrtibute.
382+
383+
Args:
384+
path: API to this symbol
385+
parent: The object
386+
children: A list of (name, object) pairs.
387+
388+
Returns:
389+
`children` filtered to respect __all__
390+
"""
391+
del path
392+
if not (inspect.ismodule(parent) and hasattr(parent, '__all__')):
393+
return children
394+
module_all = set(parent.__all__)
395+
children = [(name, value) for (name, value) in children if name in module_all]
396+
397+
return children
398+
399+
400+
def add_proto_fields(path: Sequence[str], parent: Any,
401+
children: Children) -> Children:
402+
"""Add properties to Proto classes, so they can be documented.
403+
404+
Warning: This inserts the Properties into the class so the rest of the system
405+
is unaffected. This patching is acceptable because there is never a reason to
406+
run other tensorflow code in the same process as the doc generator.
407+
408+
Args:
409+
path: API to this symbol
410+
parent: The object
411+
children: A list of (name, object) pairs.
412+
413+
Returns:
414+
`children` with proto fields added as properties.
415+
"""
416+
del path
417+
if not inspect.isclass(parent) or not issubclass(parent, ProtoMessage):
418+
return children
365419

420+
descriptor = getattr(parent, 'DESCRIPTOR', None)
421+
if descriptor is None:
366422
return children
423+
fields = descriptor.fields
424+
if not fields:
425+
return children
426+
427+
field = fields[0]
428+
# Make the dictionaries mapping from int types and labels to type and
429+
# label names.
430+
field_types = {
431+
getattr(field, name): name
432+
for name in dir(field)
433+
if name.startswith('TYPE')
434+
}
435+
436+
labels = {
437+
getattr(field, name): name
438+
for name in dir(field)
439+
if name.startswith('LABEL')
440+
}
441+
442+
field_properties = {}
443+
444+
for field in fields:
445+
name = field.name
446+
doc_parts = []
447+
448+
label = labels[field.label].lower().replace('label_', '')
449+
if label != 'optional':
450+
doc_parts.append(label)
451+
452+
type_name = field_types[field.type]
453+
if type_name == 'TYPE_MESSAGE':
454+
type_name = field.message_type.name
455+
elif type_name == 'TYPE_ENUM':
456+
type_name = field.enum_type.name
457+
else:
458+
type_name = type_name.lower().replace('type_', '')
459+
460+
doc_parts.append(type_name)
461+
doc_parts.append(name)
462+
doc = '`{}`'.format(' '.join(doc_parts))
463+
prop = property(fget=lambda x: x, doc=doc)
464+
field_properties[name] = prop
465+
466+
for name, prop in field_properties.items():
467+
setattr(parent, name, prop)
468+
469+
children = dict(children)
470+
children.update(field_properties)
471+
children = sorted(children.items(), key=lambda item: item[0])
472+
473+
return children
474+
475+
476+
def filter_builtin_modules(path: Sequence[str], parent: Any,
477+
children: Children) -> Children:
478+
"""Filters module children to remove builtin modules.
479+
480+
Args:
481+
path: API to this symbol
482+
parent: The object
483+
children: A list of (name, object) pairs.
484+
485+
Returns:
486+
`children` with all builtin modules removed.
487+
"""
488+
del path
489+
del parent
490+
# filter out 'builtin' modules
491+
filtered_children = []
492+
for name, child in children:
493+
# Do not descend into built-in modules
494+
if inspect.ismodule(child) and child.__name__ in sys.builtin_module_names:
495+
continue
496+
filtered_children.append((name, child))
497+
return filtered_children

0 commit comments

Comments
 (0)