Skip to content

Commit ee52c35

Browse files
MarkDaoustcopybara-github
authored andcommitted
Add a simplification for ast.parse
This was failing on the wrapped lines in Dataset.map (and all subclasses) because textwrap.dedent didn't remove the leading space. PiperOrigin-RevId: 438447039
1 parent 0d2801b commit ee52c35

File tree

4 files changed

+31
-20
lines changed

4 files changed

+31
-20
lines changed

tools/tensorflow_docs/api_generator/get_source.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,27 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""Simple get_source."""
16+
import ast
1617
import inspect
1718
import textwrap
1819

1920
from typing import Any, Optional, Sequence, Tuple
2021

2122

23+
def get_ast(py_object) -> Optional[ast.AST]:
24+
if isinstance(py_object, str):
25+
source = textwrap.dedent(py_object)
26+
else:
27+
source = get_source(py_object)
28+
if source is None:
29+
return None
30+
31+
try:
32+
return ast.parse(source)
33+
except Exception: # pylint: disable=broad-except
34+
return None
35+
36+
2237
def get_source(py_object: Any) -> Optional[str]:
2338
if py_object is not None:
2439
try:

tools/tensorflow_docs/api_generator/public_api.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,10 @@ def visit_Import(self, node): # pylint: disable=invalid-name
194194
def visit_ImportFrom(self, node): # pylint: disable=invalid-name
195195
self._add_imported_symbol(node)
196196

197-
if isinstance(obj, str):
198-
source = textwrap.dedent(obj)
199-
else:
200-
source = get_source.get_source(obj)
201-
if source is None:
197+
tree = get_source.get_ast(obj)
198+
if tree is None:
202199
return []
203200

204-
tree = ast.parse(source)
205201
visitor = ImportNodeVisitor()
206202
visitor.visit(tree)
207203
return visitor.imported_symbols

tools/tensorflow_docs/api_generator/report/linter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,13 @@ def lint_returns(
176176
Returns:
177177
A filled `ReturnLint` proto object.
178178
"""
179-
source = get_source.get_source(page_info.py_object)
180-
181179
return_visitor = ReturnVisitor()
182-
if source is not None:
180+
181+
source = get_source.get_source(page_info.py_object)
182+
obj_ast = get_source.get_ast(page_info.py_object)
183+
if obj_ast is not None:
183184
try:
184-
return_visitor.visit(ast.parse(source))
185+
return_visitor.visit(obj_ast)
185186
except Exception: # pylint: disable=broad-except
186187
pass
187188

@@ -236,12 +237,13 @@ def lint_raises(page_info: base_page.PageInfo) -> api_report_pb2.RaisesLint:
236237

237238
# Extract the raises from the source code.
238239
raise_visitor = RaiseVisitor()
239-
source = get_source.get_source(page_info.py_object)
240-
if source is not None:
240+
obj_ast = get_source.get_ast(page_info.py_object)
241+
if obj_ast is not None:
241242
try:
242-
raise_visitor.visit(ast.parse(source))
243+
raise_visitor.visit(obj_ast)
243244
except Exception: # pylint: disable=broad-except
244245
pass
246+
245247
raises_lint.total_raises_in_code = len(raise_visitor.total_raises)
246248

247249
# Extract the raises defined in the docstring.

tools/tensorflow_docs/api_generator/signature.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,8 @@ def _preprocess_default(self, val: ast.AST) -> str:
5656
return text_default_val
5757

5858
def extract(self, obj: Any):
59-
obj_source = get_source.get_source(obj)
60-
if obj_source is not None:
61-
obj_ast = ast.parse(obj_source)
59+
obj_ast = get_source.get_ast(obj)
60+
if obj_ast is not None:
6261
self.visit(obj_ast)
6362

6463

@@ -674,11 +673,10 @@ def visit_FunctionDef(self, node): # pylint: disable=invalid-name
674673

675674
visitor = ASTDecoratorExtractor()
676675

677-
# Note: inspect.getsource doesn't include the decorator lines on classes,
676+
# Note: get_source doesn't include the decorator lines on classes,
678677
# this won't work for classes until that's fixed.
679-
func_source = get_source.get_source(func)
680-
if func_source is not None:
681-
func_ast = ast.parse(func_source)
678+
func_ast = get_source.get_ast(func)
679+
if func_ast is not None:
682680
visitor.visit(func_ast)
683681

684682
return visitor.decorator_list

0 commit comments

Comments
 (0)