|
15 | 15 | """Visitor restricting traversal to only the public tensorflow API."""
|
16 | 16 |
|
17 | 17 | import ast
|
| 18 | +import dataclasses |
18 | 19 | import inspect
|
19 | 20 | import os
|
| 21 | +import sys |
20 | 22 | import pathlib
|
21 | 23 | import textwrap
|
22 | 24 | import types
|
23 | 25 | import typing
|
24 |
| -from typing import Any, Callable, List, Sequence, Tuple, Union |
25 |
| - |
| 26 | +from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Union |
26 | 27 |
|
27 | 28 | from tensorflow_docs.api_generator import doc_controls
|
28 | 29 | from tensorflow_docs.api_generator import doc_generator_visitor
|
29 | 30 | from tensorflow_docs.api_generator import get_source
|
30 | 31 |
|
| 32 | +from google.protobuf.message import Message as ProtoMessage |
| 33 | + |
31 | 34 | _TYPING_IDS = frozenset(
|
32 | 35 | id(obj)
|
33 | 36 | for obj in typing.__dict__.values()
|
34 | 37 | if not doc_generator_visitor.maybe_singleton(obj))
|
35 | 38 |
|
36 | 39 |
|
37 |
| -Children = List[Tuple[str, Any]] |
| 40 | +Children = Iterable[Tuple[str, Any]] |
38 | 41 | ApiFilter = Callable[[Tuple[str, ...], Any, Children], Children]
|
39 | 42 |
|
40 | 43 |
|
@@ -300,67 +303,195 @@ def explicit_package_contents_filter(path: Sequence[str], parent: Any,
|
300 | 303 | ])
|
301 | 304 |
|
302 | 305 |
|
303 |
| -class PublicAPIFilter(object): |
304 |
| - """Visitor to use with `traverse` to filter just the public API.""" |
305 |
| - |
306 |
| - def __init__(self, base_dir, private_map=None): |
307 |
| - """Constructor. |
308 |
| -
|
309 |
| - Args: |
310 |
| - base_dir: The directory to take source file paths relative to. |
311 |
| - private_map: A mapping from dotted path like "tf.symbol" to a list of |
312 |
| - names. Included names will not be listed at that location. |
313 |
| - """ |
314 |
| - self._base_dir = base_dir |
315 |
| - self._private_map = private_map or {} |
316 |
| - |
317 |
| - def _is_private(self, path, parent, name, obj): |
318 |
| - """Returns whether a name is private or not.""" |
| 306 | +@dataclasses.dataclass |
| 307 | +class FailIfNestedTooDeep: |
| 308 | + max_depth: int |
319 | 309 |
|
320 |
| - # Skip objects blocked by doc_controls. |
321 |
| - if doc_controls.should_skip(obj): |
322 |
| - return True |
| 310 | + def __call__(self, path: Sequence[str], parent: Any, |
| 311 | + children: Children) -> Children: |
| 312 | + if inspect.ismodule(parent) and len(path) > 10: |
| 313 | + raise RuntimeError('Modules nested too deep:\n\n{}\n\nThis is likely a ' |
| 314 | + 'problem with an accidental public import.'.format( |
| 315 | + '.'.join(path))) |
| 316 | + return children |
323 | 317 |
|
324 |
| - if isinstance(parent, type): |
325 |
| - if doc_controls.should_skip_class_attr(parent, name): |
326 |
| - return True |
327 | 318 |
|
328 |
| - if doc_controls.should_doc_private(obj): |
329 |
| - return False |
| 319 | +@dataclasses.dataclass |
| 320 | +class FilterBaseDirs: |
| 321 | + base_dirs: Sequence[pathlib.Path] |
330 | 322 |
|
331 |
| - if inspect.ismodule(obj): |
332 |
| - mod_base_dirs = get_module_base_dirs(obj) |
| 323 | + def __call__(self, path: Sequence[str], parent: Any, |
| 324 | + children: Children) -> Children: |
| 325 | + for name, child in children: |
| 326 | + if not inspect.ismodule(child): |
| 327 | + yield name, child |
| 328 | + continue |
| 329 | + mod_base_dirs = get_module_base_dirs(child) |
333 | 330 | # This check only handles normal packages/modules. Namespace-package
|
334 | 331 | # contents will get filtered when the submodules are checked.
|
335 | 332 | if len(mod_base_dirs) == 1:
|
336 | 333 | mod_base_dir = mod_base_dirs[0]
|
337 | 334 | # Check that module is in one of the `self._base_dir`s
|
338 |
| - if not any(base in mod_base_dir.parents for base in self._base_dir): |
339 |
| - return True |
| 335 | + if not any(base in mod_base_dir.parents for base in self.base_dirs): |
| 336 | + continue |
| 337 | + yield name, child |
| 338 | + |
340 | 339 |
|
341 |
| - # Skip objects blocked by the private_map |
342 |
| - if name in self._private_map.get('.'.join(path), []): |
343 |
| - return True |
| 340 | +@dataclasses.dataclass |
| 341 | +class FilterPrivateMap: |
| 342 | + private_map: Dict[str, List[str]] |
344 | 343 |
|
| 344 | + def __call__(self, path: Sequence[str], parent: Any, |
| 345 | + children: Children) -> Children: |
| 346 | + if self.private_map is None: |
| 347 | + yield from children |
| 348 | + |
| 349 | + for name, child in children: |
| 350 | + if name in self.private_map.get('.'.join(path), []): |
| 351 | + continue |
| 352 | + yield (name, child) |
| 353 | + |
| 354 | + |
| 355 | +def filter_private_symbols(path: Sequence[str], parent: Any, |
| 356 | + children: Children) -> Children: |
| 357 | + del path |
| 358 | + del parent |
| 359 | + for name, child in children: |
345 | 360 | # Skip "_" hidden attributes
|
346 | 361 | if name.startswith('_') and name not in ALLOWED_DUNDER_METHODS:
|
347 |
| - return True |
| 362 | + if not doc_controls.should_doc_private(child): |
| 363 | + continue |
| 364 | + yield (name, child) |
348 | 365 |
|
349 |
| - return False |
350 | 366 |
|
351 |
| - def __call__(self, path: Sequence[str], parent: Any, |
352 |
| - children: Children) -> Children: |
353 |
| - """Visitor interface, see `traverse` for details.""" |
| 367 | +def filter_doc_controls_skip(path: Sequence[str], parent: Any, |
| 368 | + children: Children) -> Children: |
| 369 | + del path |
| 370 | + for name, child in children: |
| 371 | + if doc_controls.should_skip(child): |
| 372 | + continue |
| 373 | + if isinstance(parent, type): |
| 374 | + if doc_controls.should_skip_class_attr(parent, name): |
| 375 | + continue |
| 376 | + yield (name, child) |
354 | 377 |
|
355 |
| - # Avoid long waits in cases of pretty unambiguous failure. |
356 |
| - if inspect.ismodule(parent) and len(path) > 10: |
357 |
| - raise RuntimeError('Modules nested too deep:\n\n{}\n\nThis is likely a ' |
358 |
| - 'problem with an accidental public import.'.format( |
359 |
| - '.'.join(path))) |
360 | 378 |
|
361 |
| - # Remove things that are not visible. |
362 |
| - children = [(child_name, child_obj) |
363 |
| - for child_name, child_obj in list(children) |
364 |
| - if not self._is_private(path, parent, child_name, child_obj)] |
| 379 | +def filter_module_all(path: Sequence[str], parent: Any, |
| 380 | + children: Children) -> Children: |
| 381 | + """Filters module children based on the "__all__" arrtibute. |
| 382 | +
|
| 383 | + Args: |
| 384 | + path: API to this symbol |
| 385 | + parent: The object |
| 386 | + children: A list of (name, object) pairs. |
| 387 | +
|
| 388 | + Returns: |
| 389 | + `children` filtered to respect __all__ |
| 390 | + """ |
| 391 | + del path |
| 392 | + if not (inspect.ismodule(parent) and hasattr(parent, '__all__')): |
| 393 | + return children |
| 394 | + module_all = set(parent.__all__) |
| 395 | + children = [(name, value) for (name, value) in children if name in module_all] |
| 396 | + |
| 397 | + return children |
| 398 | + |
| 399 | + |
| 400 | +def add_proto_fields(path: Sequence[str], parent: Any, |
| 401 | + children: Children) -> Children: |
| 402 | + """Add properties to Proto classes, so they can be documented. |
| 403 | +
|
| 404 | + Warning: This inserts the Properties into the class so the rest of the system |
| 405 | + is unaffected. This patching is acceptable because there is never a reason to |
| 406 | + run other tensorflow code in the same process as the doc generator. |
| 407 | +
|
| 408 | + Args: |
| 409 | + path: API to this symbol |
| 410 | + parent: The object |
| 411 | + children: A list of (name, object) pairs. |
| 412 | +
|
| 413 | + Returns: |
| 414 | + `children` with proto fields added as properties. |
| 415 | + """ |
| 416 | + del path |
| 417 | + if not inspect.isclass(parent) or not issubclass(parent, ProtoMessage): |
| 418 | + return children |
365 | 419 |
|
| 420 | + descriptor = getattr(parent, 'DESCRIPTOR', None) |
| 421 | + if descriptor is None: |
366 | 422 | return children
|
| 423 | + fields = descriptor.fields |
| 424 | + if not fields: |
| 425 | + return children |
| 426 | + |
| 427 | + field = fields[0] |
| 428 | + # Make the dictionaries mapping from int types and labels to type and |
| 429 | + # label names. |
| 430 | + field_types = { |
| 431 | + getattr(field, name): name |
| 432 | + for name in dir(field) |
| 433 | + if name.startswith('TYPE') |
| 434 | + } |
| 435 | + |
| 436 | + labels = { |
| 437 | + getattr(field, name): name |
| 438 | + for name in dir(field) |
| 439 | + if name.startswith('LABEL') |
| 440 | + } |
| 441 | + |
| 442 | + field_properties = {} |
| 443 | + |
| 444 | + for field in fields: |
| 445 | + name = field.name |
| 446 | + doc_parts = [] |
| 447 | + |
| 448 | + label = labels[field.label].lower().replace('label_', '') |
| 449 | + if label != 'optional': |
| 450 | + doc_parts.append(label) |
| 451 | + |
| 452 | + type_name = field_types[field.type] |
| 453 | + if type_name == 'TYPE_MESSAGE': |
| 454 | + type_name = field.message_type.name |
| 455 | + elif type_name == 'TYPE_ENUM': |
| 456 | + type_name = field.enum_type.name |
| 457 | + else: |
| 458 | + type_name = type_name.lower().replace('type_', '') |
| 459 | + |
| 460 | + doc_parts.append(type_name) |
| 461 | + doc_parts.append(name) |
| 462 | + doc = '`{}`'.format(' '.join(doc_parts)) |
| 463 | + prop = property(fget=lambda x: x, doc=doc) |
| 464 | + field_properties[name] = prop |
| 465 | + |
| 466 | + for name, prop in field_properties.items(): |
| 467 | + setattr(parent, name, prop) |
| 468 | + |
| 469 | + children = dict(children) |
| 470 | + children.update(field_properties) |
| 471 | + children = sorted(children.items(), key=lambda item: item[0]) |
| 472 | + |
| 473 | + return children |
| 474 | + |
| 475 | + |
| 476 | +def filter_builtin_modules(path: Sequence[str], parent: Any, |
| 477 | + children: Children) -> Children: |
| 478 | + """Filters module children to remove builtin modules. |
| 479 | +
|
| 480 | + Args: |
| 481 | + path: API to this symbol |
| 482 | + parent: The object |
| 483 | + children: A list of (name, object) pairs. |
| 484 | +
|
| 485 | + Returns: |
| 486 | + `children` with all builtin modules removed. |
| 487 | + """ |
| 488 | + del path |
| 489 | + del parent |
| 490 | + # filter out 'builtin' modules |
| 491 | + filtered_children = [] |
| 492 | + for name, child in children: |
| 493 | + # Do not descend into built-in modules |
| 494 | + if inspect.ismodule(child) and child.__name__ in sys.builtin_module_names: |
| 495 | + continue |
| 496 | + filtered_children.append((name, child)) |
| 497 | + return filtered_children |
0 commit comments