@@ -1748,18 +1748,17 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
     l_y_ = L_y_
     map_body_1 = self.map_body_1
     map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
-    getitem_1 = map_impl[0]; map_impl = None
-    return (getitem_1,)""",
+    getitem = map_impl[0]; map_impl = None
+    return (getitem,)""",
         )
         self.assertExpectedInline(
             body_graph,
             """\
 def forward(self, child : torch.Tensor, l_y_ : torch.Tensor):
-    child_1 = child[0]; child_1 = None
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [child], [l_y_]); map_body_0 = child = l_y_ = None
-    getitem_1 = map_impl[0]; map_impl = None
-    return (getitem_1,)""",
+    getitem = map_impl[0]; map_impl = None
+    return (getitem,)""",
         )
 
     def test_map_multi_return(self):
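Note on this and the following expected-graph hunks: the dead `child_1 = child[0]; child_1 = None` node is dropped from the body graph, and the `map_impl` outputs are renumbered so the first projection is plain `getitem` rather than `getitem_1`. A minimal sketch of how such a Dynamo graph can be dumped for comparison, assuming the `EagerAndRecordGraphs` backend from `torch._dynamo.testing` and the `map` control-flow op from `functorch.experimental` (both used elsewhere in this test file):

    import torch
    from torch._dynamo.testing import EagerAndRecordGraphs
    from functorch.experimental import control_flow

    backend = EagerAndRecordGraphs()

    @torch.compile(backend=backend, fullgraph=True)
    def fn(xs, y):
        # map applies the body to each slice of xs along dim 0
        return control_flow.map(lambda x, y: x.sin() + y, xs, y)

    fn(torch.randn(3, 4), torch.randn(4))
    # The captured FX graph should show a higher_order.map_impl call
    # followed by getitem projections of its outputs.
    print(backend.graphs[0].code)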
@@ -1777,9 +1776,9 @@ def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
-    getitem_1 = map_impl[0]
-    getitem_2 = map_impl[1]; map_impl = None
-    return (getitem_1, getitem_2)""",
+    getitem = map_impl[0]
+    getitem_1 = map_impl[1]; map_impl = None
+    return (getitem, getitem_1)""",
         )
         self.assertExpectedInline(
             body_graph,
@@ -1811,14 +1810,14 @@ def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
-    getitem_1 = map_impl[0]
-    getitem_2 = map_impl[1]
-    getitem_3 = map_impl[2]
-    getitem_4 = map_impl[3]
-    getitem_5 = map_impl[4]
-    getitem_6 = map_impl[5]
+    getitem = map_impl[0]
+    getitem_1 = map_impl[1]
+    getitem_2 = map_impl[2]
+    getitem_3 = map_impl[3]
+    getitem_4 = map_impl[4]
+    getitem_5 = map_impl[5]
     value = map_impl[6]; map_impl = None
-    return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, value)""",
+    return (getitem, getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, value)""",
         )
         self.assertExpectedInline(
             body_graph,
@@ -1857,8 +1856,8 @@ def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
-    getitem_1 = map_impl[0]; map_impl = None
-    return (getitem_1,)""",
+    getitem = map_impl[0]; map_impl = None
+    return (getitem,)""",
         )
         self.assertExpectedInline(
             body_graph,
@@ -1888,8 +1887,8 @@ def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
     map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
-    getitem_1 = map_impl[0]; map_impl = None
-    return (getitem_1,)""",
+    getitem = map_impl[0]; map_impl = None
+    return (getitem,)""",
         )
         self.assertExpectedInline(
             body_graph,
@@ -2279,15 +2278,12 @@ def body(x):
         mod = Module()
 
         mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
-        mod_for_eager = Module()
 
-        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        # There is graph break right when we enter body of map
-        # Since we are tracing through the Python dispatch logic, it ends up 8 graphs.
-        self.assertEqual(len(backend.graphs), 8)
-        self.assertEqual(
-            res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        )
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "map doesn't work unless it is captured completely with torch.compile",
+        ):
+            mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
 
     def test_map_side_effect(self):
         backend = EagerAndRecordGraphs()
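The rewritten assertion pins down the new contract: when Dynamo cannot capture the whole `map` body (here it graph-breaks inside the body), it raises instead of silently falling back and producing many partial graphs. A hedged standalone repro sketch, with an explicit `graph_break()` standing in for whatever unsupported construct the test's real body hits:

    import torch
    from functorch.experimental import control_flow

    def body(x):
        torch._dynamo.graph_break()  # any graph break inside the body
        return x.sin()

    @torch.compile(backend="eager", dynamic=True, fullgraph=False)
    def fn(xs):
        return control_flow.map(body, xs)

    # Raises torch._dynamo.exc.UncapturedHigherOrderOpError:
    # "map doesn't work unless it is captured completely with torch.compile"
    fn(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))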
@@ -2312,17 +2308,12 @@ def body(x):
         mod = Module()
 
         mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
-        mod_for_eager = Module()
-
-        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
 
-        eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-        eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
-
-        # Since we are tracing through the Python dispatch logic, it ends up 9 graphs.
-        self.assertEqual(len(backend.graphs), 9)
-        self.assertEqual(res, eager)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "map doesn't work unless it is captured completely with torch.compile",
+        ):
+            mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
 
     def test_wrap_subgraph_name_is_valid(self):
         backend = EagerAndRecordGraphs()
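`test_map_side_effect` gets the same treatment: a Python side effect inside the body is another way to fall out of complete capture, and it now raises the same error. Swapping a hypothetical side-effecting body into the previous sketch is enough to trigger it:

    z = []  # state living outside the traced region

    def body(x):
        z.append(x)  # side effect Dynamo cannot capture inside map
        return x.sin()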
@@ -2923,7 +2914,10 @@ def inner2(x, y):
         actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"})
         self.assertExpectedInline(
             pprint.pformat(actual_stack),
-            """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""",
+            """\
+{'add': ['map_impl', 'map_impl', 'add'],
+ 'cos': ['map_impl', 'cos'],
+ 'sin': ['sin']}""",
         )
 
     def test_grad_source_fn_stack(self):
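The updated expectation renames the stack entries from `map` to `map_impl`: `source_fn_stack` now records the underlying higher-order op rather than the Python-level wrapper, with one `map_impl` frame per nesting level (the multi-line form simply matches `pprint.pformat` output). A rough sketch of reading these stacks off a traced module, assuming `gm` is the captured `torch.fx.GraphModule` and that each stack entry is a `(name, target)` pair, as in this era of PyTorch:

    for node in gm.graph.nodes:
        stack = node.meta.get("source_fn_stack")
        if stack:
            # e.g. the 'add' node inside a doubly nested map reports
            # ['map_impl', 'map_impl', 'add']
            print(node.name, [name for name, _ in stack])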