25
25
import textwrap
26
26
import typing
27
27
28
- from typing import Any , Callable , Dict , List , Tuple , Type
28
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type
29
29
30
30
import astor
31
31
@@ -131,14 +131,6 @@ def strip_obj_addresses(text):
131
131
class FormatArguments (object ):
132
132
"""Formats the arguments and adds type annotations if they exist."""
133
133
134
- _INTERNAL_NAMES = {
135
- 'ops.GraphKeys' : 'tf.GraphKeys' ,
136
- '_ops.GraphKeys' : 'tf.GraphKeys' ,
137
- 'init_ops.zeros_initializer' : 'tf.zeros_initializer' ,
138
- 'init_ops.ones_initializer' : 'tf.ones_initializer' ,
139
- 'saver_pb2.SaverDef' : 'tf.train.SaverDef' ,
140
- }
141
-
142
134
# A regular expression capturing a python identifier.
143
135
_IDENTIFIER_RE = r'[a-zA-Z_]\w*'
144
136
@@ -166,9 +158,11 @@ def __init__(
166
158
self ._reverse_index = parser_config .reverse_index
167
159
self ._reference_resolver = parser_config .reference_resolver
168
160
169
- def get_link (self , obj_full_name : str ) -> str :
161
+ def get_link (self ,
162
+ link_text : str ,
163
+ obj_full_name : Optional [str ] = None ) -> str :
170
164
return self ._reference_resolver .python_link (
171
- link_text = obj_full_name , ref_full_name = obj_full_name )
165
+ link_text = link_text , ref_full_name = obj_full_name )
172
166
173
167
def _extract_non_builtin_types (self , arg_obj : Any ,
174
168
non_builtin_types : List [Any ]) -> List [Any ]:
@@ -252,25 +246,54 @@ def _linkify(self, non_builtin_map: Dict[str, Any], match) -> str:
252
246
253
247
return self .get_link (obj_full_name )
254
248
255
- def preprocess (self , ast_typehint : str , obj_anno : Any ) -> str :
249
+ def maybe_add_link (self , source : str , value : Any ) -> str :
250
+ """Return a link to an object's api page if found.
251
+
252
+ Args:
253
+ source: The source string from the code.
254
+ value: The value of the object.
255
+
256
+ Returns:
257
+ The original string with maybe an HTML link added.
258
+ """
259
+ cls = type (value )
260
+
261
+ value_name = self ._reverse_index .get (id (value ), None )
262
+ cls_name = self ._reverse_index .get (id (cls ), None )
263
+
264
+ if cls_name is not None :
265
+ # It's much more common for the class to be documented than the instance.
266
+ # and the class page will provide better docs.
267
+ before = source .split ('(' )[0 ]
268
+ cls_short_name = cls_name .split ('.' )[- 1 ]
269
+ if before .endswith (cls_short_name ):
270
+ # Yes, this is a guess but it will usually be right.
271
+ return self .get_link (source , cls_name )
272
+
273
+ if value_name is not None :
274
+ return self .get_link (value_name , value_name )
275
+
276
+ return source
277
+
278
+ def preprocess (self , string : str , value : Any ) -> str :
256
279
"""Links type annotations to its page if it exists.
257
280
258
281
Args:
259
- ast_typehint : AST extracted type annotation.
260
- obj_anno : Type annotation object.
282
+ string : AST extracted type annotation.
283
+ value : Type annotation object.
261
284
262
285
Returns:
263
286
Linked type annotation if the type annotation object exists.
264
287
"""
265
288
# If the object annotations exists in the reverse_index, get the link
266
289
# directly for the entire annotation.
267
- obj_anno_full_name = self ._reverse_index .get (id (obj_anno ), None )
290
+ obj_anno_full_name = self ._reverse_index .get (id (value ), None )
268
291
if obj_anno_full_name is not None :
269
292
return self .get_link (obj_anno_full_name )
270
293
271
- non_builtin_ast_types = self ._get_non_builtin_ast_types (ast_typehint )
294
+ non_builtin_ast_types = self ._get_non_builtin_ast_types (string )
272
295
try :
273
- non_builtin_type_objs = self ._extract_non_builtin_types (obj_anno , [])
296
+ non_builtin_type_objs = self ._extract_non_builtin_types (value , [])
274
297
except RecursionError :
275
298
non_builtin_type_objs = {}
276
299
@@ -282,16 +305,7 @@ def preprocess(self, ast_typehint: str, obj_anno: Any) -> str:
282
305
non_builtin_map = dict (zip (non_builtin_ast_types , non_builtin_type_objs ))
283
306
284
307
partial_func = functools .partial (self ._linkify , non_builtin_map )
285
- return self ._INDIVIDUAL_TYPES_RE .sub (partial_func , ast_typehint )
286
-
287
- def _replace_internal_names (self , default_text : str ) -> str :
288
- full_name_re = f'^{ self ._IDENTIFIER_RE } (.{ self ._IDENTIFIER_RE } )+'
289
- match = re .match (full_name_re , default_text )
290
- if match :
291
- for internal_name , public_name in self ._INTERNAL_NAMES .items ():
292
- if match .group (0 ).startswith (internal_name ):
293
- return public_name + default_text [len (internal_name ):]
294
- return default_text
308
+ return self ._INDIVIDUAL_TYPES_RE .sub (partial_func , string )
295
309
296
310
def format_return (self , return_anno : Tuple [Any , str ]) -> str :
297
311
value , source = return_anno
@@ -339,21 +353,11 @@ def format_kwargs(self, kwargs: List[inspect.Parameter]) -> List[str]:
339
353
default_text = None
340
354
if kwarg .default is not EMPTY :
341
355
default_val , default_source = kwarg .default
356
+ if default_source is None :
357
+ default_source = strip_obj_addresses (repr (default_val ))
358
+ default_source = html .escape (default_source )
342
359
343
- if id (default_val ) in self ._reverse_index :
344
- default_text = self ._reverse_index [id (default_val )]
345
- elif default_source is not None :
346
- default_text = default_source
347
- if default_text != repr (default_val ):
348
- default_text = self ._replace_internal_names (default_text )
349
- # Kwarg without default value.
350
- elif default_val is EMPTY :
351
- kwargs_text_repr .extend (self .format_args ([kwarg ]))
352
- continue
353
- else :
354
- # Strip object memory addresses to avoid unnecessary doc churn.
355
- default_text = strip_obj_addresses (repr (default_val ))
356
- default_text = html .escape (str (default_text ))
360
+ default_text = self .maybe_add_link (default_source , default_val )
357
361
358
362
# Format the kwargs to add the type annotation and default values.
359
363
typeanno = None
0 commit comments