Skip to content

Commit 6ca4820

Browse files
Add support for Python 3.10+ X | Y style unions.
PiperOrigin-RevId: 547855197
1 parent 6118199 commit 6ca4820

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

tools/tensorflow_docs/api_generator/pretty_docs/type_alias_page.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
# ==============================================================================
1515
"""Bage builder classes for type alias pages."""
1616
import textwrap
17+
import types
1718
import typing
18-
from typing import Any, List, Dict
19+
from typing import Any, Dict, List
1920

2021
from tensorflow_docs.api_generator import parser
2122
from tensorflow_docs.api_generator import signature as signature_lib
@@ -74,15 +75,17 @@ def _custom_join(self, args: List[str], origin: str) -> str:
7475
origin: Origin of a type annotation object returned by `__origin__`.
7576
7677
Returns:
77-
A joined string containing the right representation of a type annotation.
78+
A joined string containing the representation of a type annotation.
7879
"""
7980
if 'Callable' in origin:
8081
if args[0] == '...':
81-
return ', '.join(args)
82+
return 'Callable[%s]' % ', '.join(args)
8283
else:
83-
return f"[{', '.join(args[:-1])}], {args[-1]}"
84+
return 'Callable[[%s], %s]' % (', '.join(args[:-1]), args[-1])
85+
elif 'UnionType' in origin:
86+
return ' | '.join(args)
8487

85-
return ', '.join(args)
88+
return '%s[%s]' % (origin, ', '.join(args))
8689

8790
def _link_type_args(self, obj: Any, reverse_index: Dict[int, str],
8891
linker: signature_lib.FormatArguments) -> str:
@@ -95,9 +98,8 @@ def _link_type_args(self, obj: Any, reverse_index: Dict[int, str],
9598
if getattr(obj, '__args__', None):
9699
for arg in obj.__args__:
97100
result.append(self._link_type_args(arg, reverse_index, linker))
98-
origin_str = typing._type_repr(obj.__origin__) # pylint: disable=protected-access # pytype: disable=module-attr
99-
result = self._custom_join(result, origin_str)
100-
return f'{origin_str}[{result}]'
101+
origin_str = typing._type_repr(typing.get_origin(obj)) # pylint: disable=protected-access # pytype: disable=module-attr
102+
return self._custom_join(result, origin_str)
101103
else:
102104
return typing._type_repr(obj) # pylint: disable=protected-access # pytype: disable=module-attr
103105

@@ -131,15 +133,15 @@ def collect_docs(self) -> None:
131133
linker = signature_lib.FormatArguments(parser_config=self.parser_config)
132134

133135
sig_args = []
134-
if self.py_object.__origin__:
136+
if typing.get_origin(self.py_object):
135137
for arg_obj in self.py_object.__args__:
136138
sig_args.append(
137139
self._link_type_args(arg_obj, self.parser_config.reverse_index,
138140
linker))
139141

140142
sig_args_str = textwrap.indent(',\n'.join(sig_args), ' ')
141-
if self.py_object.__origin__:
142-
origin_str = typing._type_repr(self.py_object.__origin__) # pylint: disable=protected-access # pytype: disable=module-attr
143+
if typing.get_origin(self.py_object):
144+
origin_str = typing._type_repr(typing.get_origin(self.py_object)) # pylint: disable=protected-access # pytype: disable=module-attr
143145
sig = f'{origin_str}[\n{sig_args_str}\n]'
144146
else:
145147
sig = repr(self.py_object)

0 commit comments

Comments
 (0)