Skip to content

Commit 00ac423

Browse files
zou3519pytorchmergebot
authored andcommitted
[Dynamo] stop import third-party astunparse (pytorch#142503)
PyTorch's minimum version is 3.9, so we can now use ast.unparse. Test Plan: - wait for tests Pull Request resolved: pytorch#142503 Approved by: https://github.com/StrongerXi, https://github.com/yanboliang, https://github.com/mlazos ghstack dependencies: pytorch#142502
1 parent 0268abd commit 00ac423

File tree

2 files changed

+7
-88
lines changed

2 files changed

+7
-88
lines changed

test/dynamo/test_misc.py

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -8261,21 +8261,8 @@ class CSETestCase:
82618261
expr: str
82628262
preface: typing.List[str] = dataclasses.field(default_factory=list)
82638263
expected: typing.Optional[str] = None
8264-
expected_py38: typing.Optional[str] = None
8265-
8266-
def _is_py38(self) -> bool:
8267-
return sys.version_info[:2] <= (3, 8)
8268-
8269-
def _has_ast_unparse(self) -> bool:
8270-
from torch._dynamo.guards import HAS_UNPARSE_FUNCTIONS
8271-
8272-
return HAS_UNPARSE_FUNCTIONS
82738264

82748265
def test_guards_cse_pass_single(self):
8275-
if not self._has_ast_unparse():
8276-
if IS_FBCODE:
8277-
raise RuntimeError("Needs astunparse or Python-3.9+")
8278-
raise unittest.SkipTest("Needs astunparse or Python-3.9+")
82798266
from torch._dynamo.guards import PyExprCSEPass
82808267

82818268
testcase = self.CSETestCase
@@ -8320,34 +8307,28 @@ def test_guards_cse_pass_single(self):
83208307
self.assertEqual(expr, expected)
83218308

83228309
def test_guards_cse_pass_multiple(self):
8323-
if not self._has_ast_unparse():
8324-
raise unittest.SkipTest("Needs astunparse or Python-3.9+")
83258310
from torch._dynamo.guards import PyExprCSEPass
83268311

83278312
testcase = self.CSETestCase
83288313
testcases = [
83298314
testcase(
83308315
expr="x[0].a < x[1].a * (3 - x[2].a)",
83318316
expected="x[0].a < x[1].a * (3 - x[2].a)",
8332-
expected_py38="(x[0].a < (x[1].a * (3 - x[2].a)))",
83338317
),
83348318
testcase(
83358319
expr="a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0",
83368320
preface=["_var0 = a.b", "_var1 = _var0.c"],
83378321
expected="_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0",
8338-
expected_py38="((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)",
83398322
),
83408323
testcase(
83418324
expr="f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512",
83428325
preface=["_var2 = m.n", "_var3 = _var2[0]"],
83438326
expected="f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512",
8344-
expected_py38="(((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)",
83458327
),
83468328
testcase(
83478329
expr="self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k",
83488330
preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"],
83498331
expected="_var6 + (1 - _var6) <= m[0].a + _var6",
8350-
expected_py38="((_var6 + (1 - _var6)) <= (m[0].a + _var6))",
83518332
),
83528333
]
83538334

@@ -8357,7 +8338,7 @@ def test_guards_cse_pass_multiple(self):
83578338
for t in testcases:
83588339
preface, expr = csepass.replace(t.expr)
83598340
self.assertEqual(preface, t.preface)
8360-
expected = t.expected_py38 if self._is_py38() else t.expected
8341+
expected = t.expected
83618342
expected = expected if expected is not None else t.expr
83628343
self.assertEqual(expr, expected)
83638344

@@ -8393,46 +8374,7 @@ def guard(L):
83938374
return True
83948375
return guard
83958376
"""
8396-
expected_38 = """\
8397-
def ___make_guard_fn():
8398-
def guard(L):
8399-
if not ((x[0].a < (x[1].a * (3 - x[2].a)))):
8400-
return False
8401-
_var0 = a.b
8402-
_var1 = _var0.c
8403-
if not (((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)):
8404-
return False
8405-
_var2 = m.n
8406-
_var3 = _var2[0]
8407-
if not ((((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)):
8408-
return False
8409-
_var4 = self.g
8410-
_var5 = _var4(a, b)
8411-
_var6 = _var5.k
8412-
if not (((_var6 + (1 - _var6)) <= (m[0].a + _var6))):
8413-
return False
8414-
return True
8415-
return guard
8416-
"""
8417-
expected_38_no_astunparse = """\
8418-
def ___make_guard_fn():
8419-
def guard(L):
8420-
if not (x[0].a < x[1].a * (3 - x[2].a)):
8421-
return False
8422-
if not (a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0):
8423-
return False
8424-
if not (f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512):
8425-
return False
8426-
if not (self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k):
8427-
return False
8428-
return True
8429-
return guard
8430-
"""
84318377

8432-
if self._is_py38():
8433-
expected = (
8434-
expected_38 if self._has_ast_unparse() else expected_38_no_astunparse
8435-
)
84368378
self.assertEqual(expected, pycode)
84378379

84388380
def test_dynamo_compiling_fake_tensor_to_vararg_int(self):

torch/_dynamo/guards.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -439,25 +439,8 @@ def _get_closure_vars():
439439
return _CLOSURE_VARS
440440

441441

442-
if sys.version_info[:2] <= (3, 8):
443-
# [Note: Python Version <= 3.8]
444-
# This branch should be dropped when we drop support for Python 3.8.
445-
# Reason: 'ast.unparse' function was introduced in Python 3.9.
446-
447-
try:
448-
import astunparse # type: ignore[import]
449-
450-
def _ast_unparse(node: ast.AST) -> str:
451-
return astunparse.unparse(node).replace("\n", "")
452-
453-
HAS_UNPARSE_FUNCTIONS = True
454-
except ImportError:
455-
HAS_UNPARSE_FUNCTIONS = False
456-
else:
457-
HAS_UNPARSE_FUNCTIONS = True
458-
459-
def _ast_unparse(node: ast.AST) -> str:
460-
return ast.unparse(node).replace("\n", "")
442+
def _ast_unparse(node: ast.AST) -> str:
443+
return ast.unparse(node).replace("\n", "")
461444

462445

463446
def strip_function_call(name):
@@ -2588,17 +2571,11 @@ def lookup_weakrefs(self, obj):
25882571
def build_guard_function(code_parts, closure_args) -> Tuple[str, str]:
25892572
from torch._inductor.utils import IndentedBuffer
25902573

2591-
if HAS_UNPARSE_FUNCTIONS:
2592-
csepass = PyExprCSEPass()
2593-
csepass.count(code_parts)
2594-
2595-
def replace(expr: str) -> Tuple[List[str], str]:
2596-
return csepass.replace(expr)
2597-
2598-
else:
2574+
csepass = PyExprCSEPass()
2575+
csepass.count(code_parts)
25992576

2600-
def replace(expr: str) -> Tuple[List[str], str]:
2601-
return [], expr
2577+
def replace(expr: str) -> Tuple[List[str], str]:
2578+
return csepass.replace(expr)
26022579

26032580
# Generate the inner body of the guard function.
26042581
# i.e. if-chain of the guard expressions.

0 commit comments

Comments
 (0)