14
14
# ==============================================================================
15
15
"""Bage builder classes for type alias pages."""
16
16
import textwrap
17
+ import types
17
18
import typing
18
- from typing import Any , List , Dict
19
+ from typing import Any , Dict , List
19
20
20
21
from tensorflow_docs .api_generator import parser
21
22
from tensorflow_docs .api_generator import signature as signature_lib
@@ -74,15 +75,17 @@ def _custom_join(self, args: List[str], origin: str) -> str:
74
75
origin: Origin of a type annotation object returned by `__origin__`.
75
76
76
77
Returns:
77
- A joined string containing the right representation of a type annotation.
78
+ A joined string containing the representation of a type annotation.
78
79
"""
79
80
if 'Callable' in origin :
80
81
if args [0 ] == '...' :
81
- return ', ' .join (args )
82
+ return 'Callable[%s]' % ' , ' .join (args )
82
83
else :
83
- return f"[{ ', ' .join (args [:- 1 ])} ], { args [- 1 ]} "
84
+ return 'Callable[[%s], %s]' % (', ' .join (args [:- 1 ]), args [- 1 ])
85
+ elif 'UnionType' in origin :
86
+ return ' | ' .join (args )
84
87
85
- return ', ' .join (args )
88
+ return '%s[%s]' % ( origin , ', ' .join (args ) )
86
89
87
90
def _link_type_args (self , obj : Any , reverse_index : Dict [int , str ],
88
91
linker : signature_lib .FormatArguments ) -> str :
@@ -95,9 +98,8 @@ def _link_type_args(self, obj: Any, reverse_index: Dict[int, str],
95
98
if getattr (obj , '__args__' , None ):
96
99
for arg in obj .__args__ :
97
100
result .append (self ._link_type_args (arg , reverse_index , linker ))
98
- origin_str = typing ._type_repr (obj .__origin__ ) # pylint: disable=protected-access # pytype: disable=module-attr
99
- result = self ._custom_join (result , origin_str )
100
- return f'{ origin_str } [{ result } ]'
101
+ origin_str = typing ._type_repr (typing .get_origin (obj )) # pylint: disable=protected-access # pytype: disable=module-attr
102
+ return self ._custom_join (result , origin_str )
101
103
else :
102
104
return typing ._type_repr (obj ) # pylint: disable=protected-access # pytype: disable=module-attr
103
105
@@ -131,15 +133,15 @@ def collect_docs(self) -> None:
131
133
linker = signature_lib .FormatArguments (parser_config = self .parser_config )
132
134
133
135
sig_args = []
134
- if self .py_object . __origin__ :
136
+ if typing . get_origin ( self .py_object ) :
135
137
for arg_obj in self .py_object .__args__ :
136
138
sig_args .append (
137
139
self ._link_type_args (arg_obj , self .parser_config .reverse_index ,
138
140
linker ))
139
141
140
142
sig_args_str = textwrap .indent (',\n ' .join (sig_args ), ' ' )
141
- if self .py_object . __origin__ :
142
- origin_str = typing ._type_repr (self .py_object . __origin__ ) # pylint: disable=protected-access # pytype: disable=module-attr
143
+ if typing . get_origin ( self .py_object ) :
144
+ origin_str = typing ._type_repr (typing . get_origin ( self .py_object ) ) # pylint: disable=protected-access # pytype: disable=module-attr
143
145
sig = f'{ origin_str } [\n { sig_args_str } \n ]'
144
146
else :
145
147
sig = repr (self .py_object )
0 commit comments