2020restoring state changes.
2121"""
2222
23- import dataclasses
2423import inspect
2524import sys
2625import warnings
27- from typing import Callable , Optional , TYPE_CHECKING , Union
26+ from typing import TYPE_CHECKING , Union
2827
2928import torch ._C
3029from torch ._guards import Guard
5453 from torch ._dynamo .symbolic_convert import InstructionTranslator
5554
5655
57- @dataclasses .dataclass
58- class ContextManagerState :
59- """
60- Mutating `self` in VariableTracker is not allowed because we copy
61- them. This is a mutable container pointed to by context managers
62- that won't get copied, so it is safe to mutate.
63- """
64-
65- cleanup_fn : Optional [Callable ] = None
66- proxy : Optional [torch .fx .Proxy ] = None
67-
68- def cleanup (self ):
69- if self .cleanup_fn is not None :
70- self .cleanup_fn ()
71- self .cleanup_fn = None
72-
73- def cleanup_assert (self ):
74- assert self .cleanup_fn , "multiple exits?"
75- self .cleanup ()
76-
77-
7856class ContextWrappingVariable (VariableTracker ):
7957 _nonvar_fields = {
8058 "cm_obj" ,
@@ -84,13 +62,10 @@ class ContextWrappingVariable(VariableTracker):
8462 * VariableTracker ._nonvar_fields ,
8563 }
8664
87- def __init__ (
88- self , target_values , initial_values = None , * , state = None , ** kwargs
89- ) -> None :
65+ def __init__ (self , target_values , initial_values = None , ** kwargs ) -> None :
9066 super ().__init__ (** kwargs )
9167 self .target_values = target_values
9268 self .initial_values = initial_values
93- self .state = ContextManagerState () if state is None else state
9469
9570 def enter (self , tx ):
9671 self ._call_func (tx , self .target_values )
@@ -103,11 +78,11 @@ def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
10378 def fn ():
10479 self ._call_func (tx , self .initial_values )
10580
106- self .state . cleanup_fn = fn
107- tx .output .add_cleanup_hook (self .state . cleanup )
81+ self .cleanup_fn = fn
82+ tx .output .add_cleanup_hook (self .cleanup )
10883
10984 def exit (self , tx : "InstructionTranslator" , * args ):
110- self .state . cleanup_assert ()
85+ self .cleanup_assert ()
11186 return variables .ConstantVariable .create (None )
11287
11388 def reconstruct_type (self , codegen ):
@@ -152,6 +127,15 @@ def supports_graph_breaks(self):
152127 def exit_on_graph_break (self ):
153128 return True
154129
130+ def cleanup (self ):
131+ if self .cleanup_fn is not None :
132+ self .cleanup_fn ()
133+ self .cleanup_fn = None
134+
135+ def cleanup_assert (self ):
136+ assert self .cleanup_fn , "multiple exits?"
137+ self .cleanup ()
138+
155139
156140class GenericContextWrappingVariable (UserDefinedObjectVariable ):
157141 # Some methods in ContextWrappingVariable assumes the arguments are
@@ -217,7 +201,7 @@ def enter(self, tx):
217201 self .prev_state
218202 ),
219203 )
220- self .state . proxy = tx .output .create_node (
204+ self .proxy = tx .output .create_node (
221205 "call_function" ,
222206 torch ._C ._functorch .set_inplace_requires_grad_allowed ,
223207 (enabled ,),
@@ -226,7 +210,7 @@ def enter(self, tx):
226210 return variables .ConstantVariable .create (None )
227211
228212 def exit (self , tx : "InstructionTranslator" , * args ):
229- self .state . cleanup ()
213+ self .cleanup ()
230214 tx .output .create_node (
231215 "call_function" ,
232216 torch ._C ._functorch .set_inplace_requires_grad_allowed ,
@@ -253,7 +237,7 @@ def enter(self, tx):
253237 tx ,
254238 lambda : torch ._C ._functorch .push_dynamic_layer_stack (self .saved ),
255239 )
256- self .state . proxy = tx .output .create_node (
240+ self .proxy = tx .output .create_node (
257241 "call_function" ,
258242 torch ._C ._functorch .pop_dynamic_layer_stack ,
259243 (),
@@ -262,11 +246,11 @@ def enter(self, tx):
262246 return variables .ConstantVariable .create (None )
263247
264248 def exit (self , tx : "InstructionTranslator" , * args ):
265- self .state . cleanup ()
249+ self .cleanup ()
266250 tx .output .create_node (
267251 "call_function" ,
268252 torch ._C ._functorch .push_dynamic_layer_stack ,
269- (self .state . proxy ,),
253+ (self .proxy ,),
270254 {},
271255 )
272256 return variables .ConstantVariable .create (None )
@@ -297,7 +281,7 @@ def enter(self, tx):
297281 self .set_cleanup_hook (
298282 tx , lambda : torch ._functorch .eager_transforms .exit_jvp_nesting ()
299283 )
300- self .state . proxy = tx .output .create_node (
284+ self .proxy = tx .output .create_node (
301285 "call_function" ,
302286 torch ._C ._functorch ._jvp_increment_nesting ,
303287 (),
@@ -306,7 +290,7 @@ def enter(self, tx):
306290 return variables .ConstantVariable .create (jvp_level )
307291
308292 def exit (self , tx : "InstructionTranslator" , * args ):
309- self .state . cleanup ()
293+ self .cleanup ()
310294 tx .output .create_node (
311295 "call_function" , torch ._C ._functorch ._jvp_decrement_nesting , (), {}
312296 )
@@ -332,7 +316,7 @@ def enter(self, tx):
332316 tx ,
333317 lambda : torch ._C ._set_fwd_grad_enabled (self .prev_state ),
334318 )
335- self .state . proxy = tx .output .create_node (
319+ self .proxy = tx .output .create_node (
336320 "call_function" ,
337321 torch ._C ._set_fwd_grad_enabled ,
338322 (mode ,),
@@ -341,7 +325,7 @@ def enter(self, tx):
341325 return variables .ConstantVariable .create (None )
342326
343327 def exit (self , tx : "InstructionTranslator" , * args ):
344- self .state . cleanup ()
328+ self .cleanup ()
345329 tx .output .create_node (
346330 "call_function" ,
347331 torch ._C ._set_fwd_grad_enabled ,
@@ -370,7 +354,7 @@ def enter(self, tx):
370354 self .set_cleanup_hook (
371355 tx , lambda : torch .autograd .forward_ad .exit_dual_level (level = self .new_level )
372356 )
373- self .state . proxy = tx .output .create_node (
357+ self .proxy = tx .output .create_node (
374358 "call_function" ,
375359 torch ._C ._enter_dual_level ,
376360 (),
@@ -379,7 +363,7 @@ def enter(self, tx):
379363 return variables .ConstantVariable .create (self .new_level )
380364
381365 def exit (self , tx : "InstructionTranslator" , * args ):
382- self .state . cleanup ()
366+ self .cleanup ()
383367 tx .output .create_node (
384368 "call_function" ,
385369 torch ._C ._exit_dual_level ,
@@ -412,7 +396,7 @@ def enter(self, tx):
412396 install_guard (self ._guards_singleton )
413397 grad_level = torch ._C ._functorch ._grad_increment_nesting ()
414398 self .set_cleanup_hook (tx , lambda : torch ._C ._functorch ._grad_decrement_nesting ())
415- self .state . proxy = tx .output .create_node (
399+ self .proxy = tx .output .create_node (
416400 "call_function" ,
417401 torch ._C ._functorch ._grad_increment_nesting ,
418402 (),
@@ -421,7 +405,7 @@ def enter(self, tx):
421405 return variables .ConstantVariable .create (grad_level )
422406
423407 def exit (self , tx : "InstructionTranslator" , * args ):
424- self .state . cleanup ()
408+ self .cleanup ()
425409 tx .output .create_node (
426410 "call_function" , torch ._C ._functorch ._grad_decrement_nesting , (), {}
427411 )
@@ -492,7 +476,7 @@ def enter(self, tx):
492476 batch_size_value , randomness
493477 )
494478 self .set_cleanup_hook (tx , lambda : torch ._C ._functorch ._vmap_decrement_nesting ())
495- self .state . proxy = tx .output .create_node (
479+ self .proxy = tx .output .create_node (
496480 "call_function" ,
497481 torch ._C ._functorch ._vmap_increment_nesting ,
498482 (batch_size_node , randomness ),
@@ -501,7 +485,7 @@ def enter(self, tx):
501485 return variables .ConstantVariable .create (vmap_level )
502486
503487 def exit (self , tx : "InstructionTranslator" , * args ):
504- self .state . cleanup ()
488+ self .cleanup ()
505489 tx .output .create_node (
506490 "call_function" , torch ._C ._functorch ._vmap_decrement_nesting , (), {}
507491 )
@@ -589,11 +573,11 @@ def __init__(
589573 self .target_values = target_values
590574
591575 def exit (self , tx : "InstructionTranslator" , * args ):
592- self .state . cleanup_assert ()
576+ self .cleanup_assert ()
593577 tx .output .create_node (
594578 "call_function" ,
595579 torch .autograd .grad_mode ._exit_inference_mode ,
596- (self .state . proxy ,),
580+ (self .proxy ,),
597581 {},
598582 )
599583
@@ -619,7 +603,7 @@ def cleanup_hook():
619603 torch .autograd .grad_mode ._exit_inference_mode (ctx )
620604
621605 self .set_cleanup_hook (tx , cleanup_hook )
622- self .state . proxy = tx .output .create_node (
606+ self .proxy = tx .output .create_node (
623607 "call_function" ,
624608 torch .autograd .grad_mode ._enter_inference_mode ,
625609 (* self .target_values ,),
@@ -657,19 +641,19 @@ def __init__(
657641 self .target_values = target_values
658642
659643 def exit (self , tx : "InstructionTranslator" , * args ):
660- self .state . cleanup_assert ()
644+ self .cleanup_assert ()
661645 tx .output .create_node (
662646 "call_function" ,
663647 torch .cuda ._maybe_exchange_device ,
664- (self .state . proxy ,),
648+ (self .proxy ,),
665649 {},
666650 )
667651 return variables .ConstantVariable .create (False )
668652
669653 def enter (self , tx ):
670654 prev_idx = torch .cuda ._exchange_device (* self .target_values )
671655 self .set_cleanup_hook (tx , lambda : torch .cuda ._maybe_exchange_device (prev_idx ))
672- self .state . proxy = tx .output .create_node (
656+ self .proxy = tx .output .create_node (
673657 "call_function" ,
674658 torch .cuda ._exchange_device ,
675659 (* self .target_values ,),
@@ -730,8 +714,8 @@ def fn():
730714 self .initial_torch_function_subclass_enabled
731715 )
732716
733- self .state . cleanup_fn = fn
734- tx .output .add_cleanup_hook (self .state . cleanup )
717+ self .cleanup_fn = fn
718+ tx .output .add_cleanup_hook (self .cleanup )
735719
736720 def _call_func (self , tx : "InstructionTranslator" , values ):
737721 assert len (values ) == 0
@@ -885,15 +869,15 @@ def __init__(self, target_values, initial_values=None, **kwargs) -> None:
885869 self .target_values = target_values
886870
887871 def exit (self , tx : "InstructionTranslator" , * args ):
888- self .state . cleanup_assert ()
872+ self .cleanup_assert ()
889873 tx .output .create_node (
890- "call_function" , torch .amp ._exit_autocast , (self .state . proxy ,), {}
874+ "call_function" , torch .amp ._exit_autocast , (self .proxy ,), {}
891875 )
892876
893877 def enter (self , tx ):
894878 ctx = torch .amp ._enter_autocast (* self .target_values )
895879 self .set_cleanup_hook (tx , lambda : torch .amp ._exit_autocast (ctx ))
896- self .state . proxy = tx .output .create_node (
880+ self .proxy = tx .output .create_node (
897881 "call_function" , torch .amp ._enter_autocast , (* self .target_values ,), {}
898882 )
899883
@@ -1021,7 +1005,7 @@ def exit(self, tx: "InstructionTranslator", *args):
10211005 (self .initial_values [0 ].as_proxy (),),
10221006 {},
10231007 )
1024- self .state . cleanup_assert ()
1008+ self .cleanup_assert ()
10251009
10261010
10271011class PreserveVersionContextVariable (ContextWrappingVariable ):
@@ -1212,7 +1196,7 @@ def enter(self, tx):
12121196 return variables .ConstantVariable .create (None )
12131197
12141198 def exit (self , tx : "InstructionTranslator" , * args ):
1215- self .state . cleanup_assert ()
1199+ self .cleanup_assert ()
12161200 arg = self ._backends_to_nodes (tx , self .prev_backends )
12171201 tx .output .create_node (
12181202 "call_function" ,
0 commit comments