
Commit 3474367

mlazos authored and pytorchmergebot committed
[Dynamo] Cleanup state management for ctx managers (pytorch#149689)
Removes state indirection for ctx managers. This isn't needed anymore since VTs are mutable.

Pull Request resolved: pytorch#149689
Approved by: https://github.com/StrongerXi
1 parent cfc08ca commit 3474367
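
The refactor is mechanical: before this commit, each context-manager VariableTracker stashed its cleanup callback and FX proxy in a separate ContextManagerState container, because VariableTrackers used to be copied and mutating self was unsafe; now that VTs are mutable, those fields can live directly on the variable. Below is a minimal sketch of the before/after pattern, using toy standalone classes (TrackerWithIndirection and MutableTracker are illustrative names, not the real Dynamo types):

    from typing import Callable, Optional

    # Before: a shared mutable container, so that copies of the tracker
    # all point at the same cleanup_fn and mutation stays visible.
    class _State:
        def __init__(self) -> None:
            self.cleanup_fn: Optional[Callable[[], None]] = None

    class TrackerWithIndirection:
        def __init__(self) -> None:
            self.state = _State()  # copies share this container

        def enter(self, undo: Callable[[], None]) -> None:
            self.state.cleanup_fn = undo  # mutate the container, not self

    # After: trackers are mutable, so the indirection is dropped and the
    # fields plus cleanup()/cleanup_assert() move onto the tracker itself.
    class MutableTracker:
        def __init__(self) -> None:
            self.cleanup_fn: Optional[Callable[[], None]] = None

        def enter(self, undo: Callable[[], None]) -> None:
            self.cleanup_fn = undo

        def cleanup(self) -> None:
            if self.cleanup_fn is not None:
                self.cleanup_fn()
                self.cleanup_fn = None

        def cleanup_assert(self) -> None:
            assert self.cleanup_fn, "multiple exits?"
            self.cleanup()

The diff below applies exactly this move to ContextWrappingVariable and its subclasses: self.state.cleanup_fn becomes self.cleanup_fn, self.state.proxy becomes self.proxy, and cleanup()/cleanup_assert() become methods on the variable.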

File tree

1 file changed: +42 -58 lines changed


torch/_dynamo/variables/ctx_manager.py

Lines changed: 42 additions & 58 deletions
@@ -20,11 +20,10 @@
 restoring state changes.
 """
 
-import dataclasses
 import inspect
 import sys
 import warnings
-from typing import Callable, Optional, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Union
 
 import torch._C
 from torch._guards import Guard
@@ -54,27 +53,6 @@
     from torch._dynamo.symbolic_convert import InstructionTranslator
 
 
-@dataclasses.dataclass
-class ContextManagerState:
-    """
-    Mutating `self` in VariableTracker is not allowed because we copy
-    them. This is a mutable container pointed to by context managers
-    that won't get copied, so it is safe to mutate.
-    """
-
-    cleanup_fn: Optional[Callable] = None
-    proxy: Optional[torch.fx.Proxy] = None
-
-    def cleanup(self):
-        if self.cleanup_fn is not None:
-            self.cleanup_fn()
-            self.cleanup_fn = None
-
-    def cleanup_assert(self):
-        assert self.cleanup_fn, "multiple exits?"
-        self.cleanup()
-
-
 class ContextWrappingVariable(VariableTracker):
     _nonvar_fields = {
         "cm_obj",
@@ -84,13 +62,10 @@ class ContextWrappingVariable(VariableTracker):
         *VariableTracker._nonvar_fields,
     }
 
-    def __init__(
-        self, target_values, initial_values=None, *, state=None, **kwargs
-    ) -> None:
+    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
         super().__init__(**kwargs)
         self.target_values = target_values
         self.initial_values = initial_values
-        self.state = ContextManagerState() if state is None else state
 
     def enter(self, tx):
         self._call_func(tx, self.target_values)
@@ -103,11 +78,11 @@ def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
             def fn():
                 self._call_func(tx, self.initial_values)
 
-        self.state.cleanup_fn = fn
-        tx.output.add_cleanup_hook(self.state.cleanup)
+        self.cleanup_fn = fn
+        tx.output.add_cleanup_hook(self.cleanup)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         return variables.ConstantVariable.create(None)
 
     def reconstruct_type(self, codegen):
@@ -152,6 +127,15 @@ def supports_graph_breaks(self):
     def exit_on_graph_break(self):
         return True
 
+    def cleanup(self):
+        if self.cleanup_fn is not None:
+            self.cleanup_fn()
+            self.cleanup_fn = None
+
+    def cleanup_assert(self):
+        assert self.cleanup_fn, "multiple exits?"
+        self.cleanup()
+
 
 class GenericContextWrappingVariable(UserDefinedObjectVariable):
     # Some methods in ContextWrappingVariable assumes the arguments are
@@ -217,7 +201,7 @@ def enter(self, tx):
                 self.prev_state
             ),
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch.set_inplace_requires_grad_allowed,
             (enabled,),
@@ -226,7 +210,7 @@ def enter(self, tx):
         return variables.ConstantVariable.create(None)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function",
             torch._C._functorch.set_inplace_requires_grad_allowed,
@@ -253,7 +237,7 @@ def enter(self, tx):
             tx,
             lambda: torch._C._functorch.push_dynamic_layer_stack(self.saved),
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch.pop_dynamic_layer_stack,
             (),
@@ -262,11 +246,11 @@ def enter(self, tx):
         return variables.ConstantVariable.create(None)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function",
             torch._C._functorch.push_dynamic_layer_stack,
-            (self.state.proxy,),
+            (self.proxy,),
             {},
         )
         return variables.ConstantVariable.create(None)
@@ -297,7 +281,7 @@ def enter(self, tx):
         self.set_cleanup_hook(
             tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch._jvp_increment_nesting,
             (),
@@ -306,7 +290,7 @@ def enter(self, tx):
         return variables.ConstantVariable.create(jvp_level)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
         )
@@ -332,7 +316,7 @@ def enter(self, tx):
             tx,
             lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._set_fwd_grad_enabled,
             (mode,),
@@ -341,7 +325,7 @@ def enter(self, tx):
         return variables.ConstantVariable.create(None)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function",
             torch._C._set_fwd_grad_enabled,
@@ -370,7 +354,7 @@ def enter(self, tx):
         self.set_cleanup_hook(
             tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._enter_dual_level,
             (),
@@ -379,7 +363,7 @@ def enter(self, tx):
         return variables.ConstantVariable.create(self.new_level)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function",
             torch._C._exit_dual_level,
@@ -412,7 +396,7 @@ def enter(self, tx):
         install_guard(self._guards_singleton)
         grad_level = torch._C._functorch._grad_increment_nesting()
         self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch._grad_increment_nesting,
             (),
@@ -421,7 +405,7 @@ def enter(self, tx):
         return variables.ConstantVariable.create(grad_level)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function", torch._C._functorch._grad_decrement_nesting, (), {}
         )
@@ -492,7 +476,7 @@ def enter(self, tx):
             batch_size_value, randomness
         )
         self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch._vmap_increment_nesting,
             (batch_size_node, randomness),
@@ -501,7 +485,7 @@ def enter(self, tx):
         return variables.ConstantVariable.create(vmap_level)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
         )
@@ -589,11 +573,11 @@ def __init__(
         self.target_values = target_values
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         tx.output.create_node(
             "call_function",
             torch.autograd.grad_mode._exit_inference_mode,
-            (self.state.proxy,),
+            (self.proxy,),
             {},
         )
 
@@ -619,7 +603,7 @@ def cleanup_hook():
             torch.autograd.grad_mode._exit_inference_mode(ctx)
 
         self.set_cleanup_hook(tx, cleanup_hook)
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch.autograd.grad_mode._enter_inference_mode,
             (*self.target_values,),
@@ -657,19 +641,19 @@ def __init__(
         self.target_values = target_values
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         tx.output.create_node(
             "call_function",
             torch.cuda._maybe_exchange_device,
-            (self.state.proxy,),
+            (self.proxy,),
             {},
         )
         return variables.ConstantVariable.create(False)
 
     def enter(self, tx):
         prev_idx = torch.cuda._exchange_device(*self.target_values)
         self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx))
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch.cuda._exchange_device,
             (*self.target_values,),
@@ -730,8 +714,8 @@ def fn():
                 self.initial_torch_function_subclass_enabled
             )
 
-        self.state.cleanup_fn = fn
-        tx.output.add_cleanup_hook(self.state.cleanup)
+        self.cleanup_fn = fn
+        tx.output.add_cleanup_hook(self.cleanup)
 
     def _call_func(self, tx: "InstructionTranslator", values):
         assert len(values) == 0
@@ -885,15 +869,15 @@ def __init__(self, target_values, initial_values=None, **kwargs) -> None:
         self.target_values = target_values
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         tx.output.create_node(
-            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
+            "call_function", torch.amp._exit_autocast, (self.proxy,), {}
         )
 
     def enter(self, tx):
         ctx = torch.amp._enter_autocast(*self.target_values)
         self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
         )
 
@@ -1021,7 +1005,7 @@ def exit(self, tx: "InstructionTranslator", *args):
             (self.initial_values[0].as_proxy(),),
             {},
         )
-        self.state.cleanup_assert()
+        self.cleanup_assert()
 
 
 class PreserveVersionContextVariable(ContextWrappingVariable):
@@ -1212,7 +1196,7 @@ def enter(self, tx):
         return variables.ConstantVariable.create(None)
 
     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         arg = self._backends_to_nodes(tx, self.prev_backends)
         tx.output.create_node(
             "call_function",
