Skip to content

Commit 9edcd05

Browse files
MarkDaoustcopybara-github
authored andcommitted
Link signature defaults for known objects to API locations.
When a default contains an instance, prefer to link to the **class** if available. Most of the time if you link to an instance it's a dataclass default. For the main TF API this mainly links simple defaults like enums, dtypes and activation functions back to their class pages. For tensorflow_models this makes the pages for complex nested configs like [OptimizerConfig](https://github.com/tensorflow/models/blob/3d0e12fdd61f1e9e0515b5caed2a33635ae0c8c3/official/modeling/optimization/configs/optimization_config.py#L32) more useful by linking all the nested dataclass instance pages back to their class' page. + drop some TF1 special cases. PiperOrigin-RevId: 436614177
1 parent c0c4ab3 commit 9edcd05

File tree

3 files changed

+66
-46
lines changed

3 files changed

+66
-46
lines changed

tools/tensorflow_docs/api_generator/reference_resolver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def one_ref(match):
290290

291291
return '\n'.join(fixed_lines)
292292

293-
def python_link(self, link_text, ref_full_name):
293+
def python_link(self, link_text: str, ref_full_name: Optional[str] = None):
294294
"""Resolve a "`tf.symbol`" reference to a link.
295295
296296
This will pick the canonical location for duplicate symbols.
@@ -302,6 +302,8 @@ def python_link(self, link_text, ref_full_name):
302302
Returns:
303303
A link to the documentation page of `ref_full_name`.
304304
"""
305+
if ref_full_name is None:
306+
ref_full_name = link_text
305307
link_text = html.escape(link_text, quote=True)
306308

307309
url = self.reference_to_url(ref_full_name)

tools/tensorflow_docs/api_generator/signature.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import textwrap
2626
import typing
2727

28-
from typing import Any, Callable, Dict, List, Tuple, Type
28+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
2929

3030
import astor
3131

@@ -131,14 +131,6 @@ def strip_obj_addresses(text):
131131
class FormatArguments(object):
132132
"""Formats the arguments and adds type annotations if they exist."""
133133

134-
_INTERNAL_NAMES = {
135-
'ops.GraphKeys': 'tf.GraphKeys',
136-
'_ops.GraphKeys': 'tf.GraphKeys',
137-
'init_ops.zeros_initializer': 'tf.zeros_initializer',
138-
'init_ops.ones_initializer': 'tf.ones_initializer',
139-
'saver_pb2.SaverDef': 'tf.train.SaverDef',
140-
}
141-
142134
# A regular expression capturing a python identifier.
143135
_IDENTIFIER_RE = r'[a-zA-Z_]\w*'
144136

@@ -166,9 +158,11 @@ def __init__(
166158
self._reverse_index = parser_config.reverse_index
167159
self._reference_resolver = parser_config.reference_resolver
168160

169-
def get_link(self, obj_full_name: str) -> str:
161+
def get_link(self,
162+
link_text: str,
163+
obj_full_name: Optional[str] = None) -> str:
170164
return self._reference_resolver.python_link(
171-
link_text=obj_full_name, ref_full_name=obj_full_name)
165+
link_text=link_text, ref_full_name=obj_full_name)
172166

173167
def _extract_non_builtin_types(self, arg_obj: Any,
174168
non_builtin_types: List[Any]) -> List[Any]:
@@ -252,25 +246,54 @@ def _linkify(self, non_builtin_map: Dict[str, Any], match) -> str:
252246

253247
return self.get_link(obj_full_name)
254248

255-
def preprocess(self, ast_typehint: str, obj_anno: Any) -> str:
249+
def maybe_add_link(self, source: str, value: Any) -> str:
250+
"""Return a link to an object's api page if found.
251+
252+
Args:
253+
source: The source string from the code.
254+
value: The value of the object.
255+
256+
Returns:
257+
The original string with maybe an HTML link added.
258+
"""
259+
cls = type(value)
260+
261+
value_name = self._reverse_index.get(id(value), None)
262+
cls_name = self._reverse_index.get(id(cls), None)
263+
264+
if cls_name is not None:
265+
# It's much more common for the class to be documented than the instance.
266+
# and the class page will provide better docs.
267+
before = source.split('(')[0]
268+
cls_short_name = cls_name.split('.')[-1]
269+
if before.endswith(cls_short_name):
270+
# Yes, this is a guess but it will usually be right.
271+
return self.get_link(source, cls_name)
272+
273+
if value_name is not None:
274+
return self.get_link(value_name, value_name)
275+
276+
return source
277+
278+
def preprocess(self, string: str, value: Any) -> str:
256279
"""Links type annotations to its page if it exists.
257280
258281
Args:
259-
ast_typehint: AST extracted type annotation.
260-
obj_anno: Type annotation object.
282+
string: AST extracted type annotation.
283+
value: Type annotation object.
261284
262285
Returns:
263286
Linked type annotation if the type annotation object exists.
264287
"""
265288
# If the object annotations exists in the reverse_index, get the link
266289
# directly for the entire annotation.
267-
obj_anno_full_name = self._reverse_index.get(id(obj_anno), None)
290+
obj_anno_full_name = self._reverse_index.get(id(value), None)
268291
if obj_anno_full_name is not None:
269292
return self.get_link(obj_anno_full_name)
270293

271-
non_builtin_ast_types = self._get_non_builtin_ast_types(ast_typehint)
294+
non_builtin_ast_types = self._get_non_builtin_ast_types(string)
272295
try:
273-
non_builtin_type_objs = self._extract_non_builtin_types(obj_anno, [])
296+
non_builtin_type_objs = self._extract_non_builtin_types(value, [])
274297
except RecursionError:
275298
non_builtin_type_objs = {}
276299

@@ -282,16 +305,7 @@ def preprocess(self, ast_typehint: str, obj_anno: Any) -> str:
282305
non_builtin_map = dict(zip(non_builtin_ast_types, non_builtin_type_objs))
283306

284307
partial_func = functools.partial(self._linkify, non_builtin_map)
285-
return self._INDIVIDUAL_TYPES_RE.sub(partial_func, ast_typehint)
286-
287-
def _replace_internal_names(self, default_text: str) -> str:
288-
full_name_re = f'^{self._IDENTIFIER_RE}(.{self._IDENTIFIER_RE})+'
289-
match = re.match(full_name_re, default_text)
290-
if match:
291-
for internal_name, public_name in self._INTERNAL_NAMES.items():
292-
if match.group(0).startswith(internal_name):
293-
return public_name + default_text[len(internal_name):]
294-
return default_text
308+
return self._INDIVIDUAL_TYPES_RE.sub(partial_func, string)
295309

296310
def format_return(self, return_anno: Tuple[Any, str]) -> str:
297311
value, source = return_anno
@@ -339,21 +353,11 @@ def format_kwargs(self, kwargs: List[inspect.Parameter]) -> List[str]:
339353
default_text = None
340354
if kwarg.default is not EMPTY:
341355
default_val, default_source = kwarg.default
356+
if default_source is None:
357+
default_source = strip_obj_addresses(repr(default_val))
358+
default_source = html.escape(default_source)
342359

343-
if id(default_val) in self._reverse_index:
344-
default_text = self._reverse_index[id(default_val)]
345-
elif default_source is not None:
346-
default_text = default_source
347-
if default_text != repr(default_val):
348-
default_text = self._replace_internal_names(default_text)
349-
# Kwarg without default value.
350-
elif default_val is EMPTY:
351-
kwargs_text_repr.extend(self.format_args([kwarg]))
352-
continue
353-
else:
354-
# Strip object memory addresses to avoid unnecessary doc churn.
355-
default_text = strip_obj_addresses(repr(default_val))
356-
default_text = html.escape(str(default_text))
360+
default_text = self.maybe_add_link(default_source, default_val)
357361

358362
# Format the kwargs to add the type annotation and default values.
359363
typeanno = None

tools/tensorflow_docs/api_generator/signature_test.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,29 @@ def setUp(self):
5050
super().setUp()
5151
self.known_object = object()
5252
reference_resolver = reference_resolver_lib.ReferenceResolver(
53-
duplicate_of={},
53+
link_prefix='/',
54+
duplicate_of={
55+
'tfdocs.api_generator.signature.extract_decorators':
56+
'tfdocs.api_generator.signature.extract_decorators',
57+
'location.of.object.in.api':
58+
'location.of.object.in.api',
59+
},
5460
is_fragment={
55-
'tfdocs.api_generator.signature.extract_decorators': False
61+
'location.of.object.in.api': False,
62+
'tfdocs.api_generator.signature.extract_decorators': False,
5663
},
5764
py_module_names=[])
5865
self.parser_config = config.ParserConfig(
5966
reference_resolver=reference_resolver,
6067
duplicates={},
6168
duplicate_of={},
6269
tree={},
63-
index={},
70+
index={
71+
'location.of.object.in.api':
72+
self.known_object,
73+
'tfdocs.api_generator.signature.extract_decorators':
74+
signature.extract_decorators
75+
},
6476
reverse_index={
6577
id(self.known_object):
6678
'location.of.object.in.api',
@@ -79,7 +91,9 @@ def example_fun(arg=self.known_object): # pylint: disable=unused-argument
7991
example_fun,
8092
parser_config=self.parser_config,
8193
func_type=signature.FuncType.FUNCTION)
82-
self.assertEqual('(\n arg=location.of.object.in.api\n)', str(sig))
94+
self.assertEqual(
95+
'(\n arg=<a href="/location/of/object/in/api.md"><code>location.of.object.in.api</code></a>\n)',
96+
str(sig))
8397

8498
def test_literals(self):
8599

0 commit comments

Comments
 (0)