@@ -8261,21 +8261,8 @@ class CSETestCase:
82618261 expr : str
82628262 preface : typing .List [str ] = dataclasses .field (default_factory = list )
82638263 expected : typing .Optional [str ] = None
8264- expected_py38 : typing .Optional [str ] = None
8265-
8266- def _is_py38 (self ) -> bool :
8267- return sys .version_info [:2 ] <= (3 , 8 )
8268-
8269- def _has_ast_unparse (self ) -> bool :
8270- from torch ._dynamo .guards import HAS_UNPARSE_FUNCTIONS
8271-
8272- return HAS_UNPARSE_FUNCTIONS
82738264
82748265 def test_guards_cse_pass_single (self ):
8275- if not self ._has_ast_unparse ():
8276- if IS_FBCODE :
8277- raise RuntimeError ("Needs astunparse or Python-3.9+" )
8278- raise unittest .SkipTest ("Needs astunparse or Python-3.9+" )
82798266 from torch ._dynamo .guards import PyExprCSEPass
82808267
82818268 testcase = self .CSETestCase
@@ -8320,34 +8307,28 @@ def test_guards_cse_pass_single(self):
83208307 self .assertEqual (expr , expected )
83218308
83228309 def test_guards_cse_pass_multiple (self ):
8323- if not self ._has_ast_unparse ():
8324- raise unittest .SkipTest ("Needs astunparse or Python-3.9+" )
83258310 from torch ._dynamo .guards import PyExprCSEPass
83268311
83278312 testcase = self .CSETestCase
83288313 testcases = [
83298314 testcase (
83308315 expr = "x[0].a < x[1].a * (3 - x[2].a)" ,
83318316 expected = "x[0].a < x[1].a * (3 - x[2].a)" ,
8332- expected_py38 = "(x[0].a < (x[1].a * (3 - x[2].a)))" ,
83338317 ),
83348318 testcase (
83358319 expr = "a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0" ,
83368320 preface = ["_var0 = a.b" , "_var1 = _var0.c" ],
83378321 expected = "_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0" ,
8338- expected_py38 = "((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)" ,
83398322 ),
83408323 testcase (
83418324 expr = "f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512" ,
83428325 preface = ["_var2 = m.n" , "_var3 = _var2[0]" ],
83438326 expected = "f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512" ,
8344- expected_py38 = "(((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)" ,
83458327 ),
83468328 testcase (
83478329 expr = "self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k" ,
83488330 preface = ["_var4 = self.g" , "_var5 = _var4(a, b)" , "_var6 = _var5.k" ],
83498331 expected = "_var6 + (1 - _var6) <= m[0].a + _var6" ,
8350- expected_py38 = "((_var6 + (1 - _var6)) <= (m[0].a + _var6))" ,
83518332 ),
83528333 ]
83538334
@@ -8357,7 +8338,7 @@ def test_guards_cse_pass_multiple(self):
83578338 for t in testcases :
83588339 preface , expr = csepass .replace (t .expr )
83598340 self .assertEqual (preface , t .preface )
8360- expected = t .expected_py38 if self . _is_py38 () else t . expected
8341+ expected = t .expected
83618342 expected = expected if expected is not None else t .expr
83628343 self .assertEqual (expr , expected )
83638344
@@ -8393,46 +8374,7 @@ def guard(L):
83938374 return True
83948375 return guard
83958376"""
8396- expected_38 = """\
8397- def ___make_guard_fn():
8398- def guard(L):
8399- if not ((x[0].a < (x[1].a * (3 - x[2].a)))):
8400- return False
8401- _var0 = a.b
8402- _var1 = _var0.c
8403- if not (((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)):
8404- return False
8405- _var2 = m.n
8406- _var3 = _var2[0]
8407- if not ((((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)):
8408- return False
8409- _var4 = self.g
8410- _var5 = _var4(a, b)
8411- _var6 = _var5.k
8412- if not (((_var6 + (1 - _var6)) <= (m[0].a + _var6))):
8413- return False
8414- return True
8415- return guard
8416- """
8417- expected_38_no_astunparse = """\
8418- def ___make_guard_fn():
8419- def guard(L):
8420- if not (x[0].a < x[1].a * (3 - x[2].a)):
8421- return False
8422- if not (a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0):
8423- return False
8424- if not (f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512):
8425- return False
8426- if not (self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k):
8427- return False
8428- return True
8429- return guard
8430- """
84318377
8432- if self ._is_py38 ():
8433- expected = (
8434- expected_38 if self ._has_ast_unparse () else expected_38_no_astunparse
8435- )
84368378 self .assertEqual (expected , pycode )
84378379
84388380 def test_dynamo_compiling_fake_tensor_to_vararg_int (self ):
0 commit comments