Skip to content

Commit 69ddbac

Browse files
swegnercopybara-github
authored andcommitted
Fix kwonlyargs default value parsing for compulsory keyword-only args ordered before optional keyword-only args.
Previously, the logic in `FormatArguments.format_kwargs` had an assumption about the ordering of compulsory vs. optional kwonlyargs parameters. Optional kwonlyargs were assumed to come before non-optional kwonlyargs. In places where this wasn't the case, the default value would get attached to the wrong keyword-only arg parameter in the output signature. PiperOrigin-RevId: 411085094
1 parent bd24f51 commit 69ddbac

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

tools/tensorflow_docs/api_generator/parser.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,8 +1081,8 @@ class ASTDefaultValueExtractor(ast.NodeVisitor):
10811081
_PAREN_NUMBER_RE = re.compile(r'^\(([0-9.e-]+)\)')
10821082

10831083
def __init__(self):
1084-
self.ast_args_defaults = []
1085-
self.ast_kw_only_defaults = []
1084+
self.ast_args_defaults = {}
1085+
self.ast_kw_only_defaults = {}
10861086

10871087
def _preprocess(self, val: str) -> str:
10881088
text_default_val = astor.to_source(val).strip().replace(
@@ -1093,15 +1093,20 @@ def _preprocess(self, val: str) -> str:
10931093
def visit_FunctionDef(self, node) -> None: # pylint: disable=invalid-name
10941094
"""Visits the `FunctionDef` node and extracts the default values."""
10951095

1096-
for default_val in node.args.defaults:
1096+
# From https://docs.python.org/3/library/ast.html#ast.arguments:
1097+
# `defaults` is a list of default values for arguments that can be passed
1098+
# positionally. If there are fewer defaults, they correspond to the last
1099+
# n arguments.
1100+
last_n_pos_args = node.args.args[-1 * len(node.args.defaults):]
1101+
for arg, default_val in zip(last_n_pos_args, node.args.defaults):
10971102
if default_val is not None:
10981103
text_default_val = self._preprocess(default_val)
1099-
self.ast_args_defaults.append(text_default_val)
1104+
self.ast_args_defaults[arg.arg] = text_default_val
11001105

1101-
for default_val in node.args.kw_defaults:
1106+
for kwarg, default_val in zip(node.args.kwonlyargs, node.args.kw_defaults):
11021107
if default_val is not None:
11031108
text_default_val = self._preprocess(default_val)
1104-
self.ast_kw_only_defaults.append(text_default_val)
1109+
self.ast_kw_only_defaults[kwarg.arg] = text_default_val
11051110

11061111

11071112
class FormatArguments(object):
@@ -1307,7 +1312,7 @@ def format_args(self, args: List[inspect.Parameter]) -> List[str]:
13071312
return args_text_repr
13081313

13091314
def format_kwargs(self, kwargs: List[inspect.Parameter],
1310-
ast_defaults: List[str]) -> List[str]:
1315+
ast_defaults: Dict[str, str]) -> List[str]:
13111316
"""Creates a text representation of the kwargs in a method/function.
13121317
13131318
Args:
@@ -1320,11 +1325,9 @@ def format_kwargs(self, kwargs: List[inspect.Parameter],
13201325

13211326
kwargs_text_repr = []
13221327

1323-
if len(ast_defaults) < len(kwargs):
1324-
ast_defaults.extend([None] * (len(kwargs) - len(ast_defaults))) # pytype: disable=container-type-mismatch
1325-
1326-
for kwarg, ast_default in zip(kwargs, ast_defaults):
1328+
for kwarg in kwargs:
13271329
kname = kwarg.name
1330+
ast_default = ast_defaults.get(kname)
13281331
default_val = kwarg.default
13291332

13301333
if id(default_val) in self._reverse_index:

tools/tensorflow_docs/api_generator/parser_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,7 +1313,7 @@ def example_fun(arg1=a.b.c.d, arg2=a.b.c.d(1, 2), arg3=e['f']): # pylint: disab
13131313

13141314
def test_compulsory_kwargs_without_defaults(self):
13151315

1316-
def example_fun(x, z, a=True, b='test', *, y=None, c, **kwargs) -> bool: # pylint: disable=unused-argument
1316+
def example_fun(x, z, a=True, b='test', *, c, y=None, d, **kwargs) -> bool: # pylint: disable=unused-argument
13171317
return True
13181318

13191319
sig = parser.generate_signature(
@@ -1322,7 +1322,8 @@ def example_fun(x, z, a=True, b='test', *, y=None, c, **kwargs) -> bool: # pyli
13221322
func_full_name='',
13231323
func_type=parser.FuncType.FUNCTION)
13241324
self.assertEqual(sig.arguments, [
1325-
'x', 'z', 'a=True', 'b=&#x27;test&#x27;', '*', 'y=None', 'c', '**kwargs'
1325+
'x', 'z', 'a=True', 'b=&#x27;test&#x27;', '*', 'c', 'y=None', 'd',
1326+
'**kwargs'
13261327
])
13271328
self.assertEqual(sig.return_type, 'bool')
13281329
self.assertEqual(sig.arguments_typehint_exists, False)

0 commit comments

Comments
 (0)