
Commit 49fa437

xmfan authored and pytorchmergebot committed
[compiled autograd] Compiled autograd configs in TLS (pytorch#137821)
Multithreading doesn't work yet; this adds Python-side TLS only for the Python-side state.

Pull Request resolved: pytorch#137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: pytorch#137953
1 parent 7525914 commit 49fa437
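The change replaces the module-level globals compiled_autograd_enabled and in_compiled_autograd_region with a thread-local wrapper exposed as compiled_autograd.local. A minimal before/after sketch of the call-site pattern, using only names that appear in this diff (the two helper functions are hypothetical, for illustration):

import torch._dynamo.compiled_autograd as compiled_autograd

def compiled_autograd_is_enabled() -> bool:
    # before this commit: compiled_autograd.compiled_autograd_enabled (module-level global)
    return compiled_autograd.local.enabled()

def in_compiled_autograd_backward() -> bool:
    # before this commit: compiled_autograd.in_compiled_autograd_region (module-level global)
    return compiled_autograd.local.get("in_compiled_autograd_region")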

File tree

16 files changed: +221 -103 lines


test/distributed/_composable/fsdp/test_fully_shard_compile.py

Lines changed: 1 addition & 1 deletion
@@ -256,7 +256,7 @@ def _check_count(copy_count, resize_count):
                 f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}",  # noqa: B950
             )

-        if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+        if not torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
             _check_count(fwd_copy_count, fwd_resize_count)  # fwd graph
         else:
             _check_count(bwd_copy_count, bwd_resize_count)  # bwd graph

test/dynamo/test_activation_checkpointing.py

Lines changed: 3 additions & 1 deletion
@@ -86,7 +86,9 @@ def match_rng_op(node, op):


 def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
-    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:  # fwd graph
+    if not torch._dynamo.compiled_autograd.local.get(
+        "in_compiled_autograd_region"
+    ):  # fwd graph
         return_node = list(graph.nodes)[-1]
         assert return_node.target == "output"
         for x in return_node.args[0]:

test/inductor/test_compiled_autograd.py

Lines changed: 35 additions & 0 deletions
@@ -6,9 +6,11 @@
 import itertools
 import logging
 import os
+import queue
 import re
 import subprocess
 import sys
+import threading
 import unittest
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
@@ -2405,6 +2407,39 @@ def test_logs(self):
             not in logs.getvalue()
         )

+    def test_multithreading_tls(self):
+        def train(errors, model, x):
+            try:
+                out = model(x)
+                with compiled_autograd.enable(compiler_fn):
+                    self.assertEqual(compiled_autograd.local.enabled(), True)
+                    self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1)
+            except Exception as e:
+                print(f"Found error: {e}")
+                errors.put(1)
+                raise
+
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 4),
+            torch.nn.ReLU(),
+        )
+        x = torch.randn([2, 4])
+
+        threads = []
+        errors = queue.Queue()
+        with compiled_autograd.enable(compiler_fn):
+            for i in range(4):
+                thread = threading.Thread(target=train, args=(errors, model, x))
+                threads.append(thread)
+                thread.start()
+
+        for thread in threads:
+            thread.join()
+
+        assert errors.empty()
+
     def test_verbose_logs_graph(self):
         def fn():
             model = torch.nn.Sequential(
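For reference, a minimal single-threaded usage sketch of the enable() API exercised by this test, assuming a build that includes this change; compiler_fn mirrors the pattern used elsewhere in this test file:

import torch
import torch._dynamo.compiled_autograd as compiled_autograd

def compiler_fn(gm):
    # compile each compiled-autograd backward graph
    return torch.compile(gm, backend="eager")

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

with compiled_autograd.enable(compiler_fn):
    # inside the context, this thread's TLS reports compiled autograd as enabled
    assert compiled_autograd.local.enabled()
    model(x).sum().backward()

# on exit, the revert callback returned by set_tls() restores the prior thread-local state
assert not compiled_autograd.local.enabled()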
Lines changed: 1 addition & 5 deletions
@@ -1,10 +1,6 @@
 from typing import Callable

-from torch._dynamo.compiled_autograd import AutogradCompilerInstance
-
-def set_autograd_compiler(
-    autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
-) -> Callable[[], AutogradCompilerInstance] | None: ...
+def notify_autograd_engine() -> None: ...
 def clear_cache() -> None: ...
 def is_cache_empty() -> bool: ...
 def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...

torch/_dynamo/compiled_autograd.py

Lines changed: 112 additions & 37 deletions
@@ -1,7 +1,10 @@
 # mypy: allow-untyped-defs
 import contextlib
 import functools
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+import threading
+from dataclasses import dataclass
+from logging import Logger
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import torch
 from torch._dynamo.external_utils import (
@@ -38,14 +41,90 @@
 verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")


-def snapshot_verbose_logging_enabled():
-    return torch._logging._internal.log_state.is_artifact_enabled(
-        "compiled_autograd_verbose"
-    )
+@dataclass
+class CompiledAutogradTLS:
+    next_ctx_id: int = 0
+    in_compiled_autograd_region: bool = False
+    compiler: Optional["AutogradCompilerInstance"] = None
+    vlogger: Optional[Logger] = None
+
+
+class TLSWrapper:
+    tls_key = "compiled_autograd_state"
+
+    def __init__(self):
+        self._local = threading.local()
+
+    def _get_tls(self) -> CompiledAutogradTLS:
+        if hasattr(self._local, self.tls_key):
+            # first look in python
+            state = getattr(self._local, self.tls_key)
+        if torch._C._is_key_in_tls(self.tls_key):
+            # then look in cpp
+            state = torch._C._get_obj_in_tls(self.tls_key)
+        else:
+            # init new thread created outside of autograd
+            # TODO: what if context manager wrapped outside of thread?
+            setattr(self._local, self.tls_key, CompiledAutogradTLS())
+            state = getattr(self._local, self.tls_key)
+            torch._C._stash_obj_in_tls(self.tls_key, state)
+        return state
+
+    # queries on the object stored in TLS
+    def get(self, name):
+        return getattr(self._get_tls(), name)
+
+    def set_tls(self, **kwargs) -> Callable[[], None]:
+        priors: Dict[str, Any] = {}
+        for k, v in kwargs.items():
+            state = self._get_tls()
+            priors[k] = getattr(state, k)
+            setattr(state, k, v)
+
+        torch._C._dynamo.compiled_autograd.notify_autograd_engine()
+
+        def revert():
+            self.set_tls(**priors)
+
+        return revert
+
+    def enabled(self) -> bool:
+        return self.get("compiler") is not None
+
+    def enter_ctx(self) -> Callable[[], None]:
+        state = self._get_tls()
+        state.next_ctx_id += 1
+        id = state.next_ctx_id
+
+        def exit():
+            assert (
+                state is self._get_tls()
+            ), "Runtime must begin and end on the same thread"
+            assert state.next_ctx_id == id, (
+                "Error nesting compiled autograd context managers: "
+                "inner context managers must have shorter lifetime than the outer context manager"
+            )
+            state.next_ctx_id -= 1
+
+        return exit
+
+    def enter_compiled_region(self) -> Callable[[], None]:
+        state = self._get_tls()
+        prior = state.in_compiled_autograd_region
+        state.in_compiled_autograd_region = True
+        assert prior is False, "Nested compiled autograd regions are not supported"
+
+        def exit():
+            assert (
+                state is self._get_tls()
+            ), "Runtime must begin and end on the same thread"
+            assert state.in_compiled_autograd_region is True
+            state.in_compiled_autograd_region = prior
+
+        return exit


-def snapshot_cudagraph_enabled():
-    return torch._inductor.config.triton.cudagraphs
+local = TLSWrapper()


 def maybe_clone(x):
@@ -307,7 +386,7 @@ def end_capture(self, outputs):
         self.rename_aot_dispatcher_nodes()
         self.reorder_accumulate_grad_nodes()
         runtime_inputs_to_move: List[int] = []
-        if snapshot_cudagraph_enabled():
+        if torch._inductor.config.triton.cudagraphs:
             runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)

         graph = GraphModule(
@@ -329,16 +408,15 @@ def end_capture(self, outputs):
         )

         def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
-            global in_compiled_autograd_region
             try:
-                in_compiled_autograd_region = True
+                exit_compiled_region = local.enter_compiled_region()
                 for i in runtime_inputs_to_move:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

                 with disable():
                     return compiled_fn(inputs, sizes, scalars, hooks)
             finally:
-                in_compiled_autograd_region = False
+                exit_compiled_region()

         return runtime_wrapper, self.compiler_fn(graph)

@@ -510,15 +588,9 @@ def set_node_origin(
     set_stack_trace(new_stack_trace)


-# state of the autograd engine dispatch, kept in sync by enable/disable context managers
-compiled_autograd_enabled = False
-
 # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager"
 compiled_autograd_enabled_force_eager = False

-# global flag to check if we are processing graphs produced from a compiled autograd graph
-in_compiled_autograd_region = False
-

 @contextlib.contextmanager
 def enable(compiler_fn):
@@ -538,39 +610,42 @@ def enable(compiler_fn):
     # we need to lazily import it, because of circular dependencies
     import torch._inductor.cudagraph_trees

-    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
-        functools.partial(AutogradCompilerInstance, compiler_fn)
+    exit_ctx = local.enter_ctx()
+    revert_tls = local.set_tls(
+        compiler=functools.partial(AutogradCompilerInstance, compiler_fn),
+        vlogger=verbose_log
+        if torch._logging._internal.log_state.is_artifact_enabled(
+            "compiled_autograd_verbose"
+        )
+        else None,
     )
-    if snapshot_verbose_logging_enabled():
-        torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = True
     try:
         with torch.autograd.set_multithreading_enabled(False):
             yield
     finally:
-        if not prior:
-            compiled_autograd_enabled = False
-        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+        revert_tls()
+        exit_ctx()


 @contextlib.contextmanager
 def disable():
-    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = False
+    exit_ctx = local.enter_ctx()
+    revert_tls = local.set_tls(
+        compiler=None,
+        vlogger=None,
+    )
     try:
         yield
     finally:
-        if prior:
-            compiled_autograd_enabled = True
-        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+        revert_tls()
+        exit_ctx()


 # return to starting state of a new process
 def reset() -> None:
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = False
-    assert not in_compiled_autograd_region
-    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
-    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
+    assert local.get("next_ctx_id") == 0
+    assert local.get("in_compiled_autograd_region") is False
+    local.set_tls(
+        compiler=None,
+        vlogger=None,
+    )
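The core of the new module state is the TLSWrapper above: per-thread state in a dataclass, mutated through set_tls(), which returns a revert callback that enable()/disable() invoke on exit. A simplified, self-contained sketch of that pattern (Python-side only; the real class also mirrors the state into C++ TLS via torch._C._stash_obj_in_tls and calls notify_autograd_engine() so the engine observes the update):

import threading
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class _State:
    next_ctx_id: int = 0
    in_compiled_autograd_region: bool = False
    compiler: Optional[Callable] = None


class _TLS:
    def __init__(self) -> None:
        self._local = threading.local()

    def _get(self) -> _State:
        # each thread lazily gets its own independent state object
        if not hasattr(self._local, "state"):
            self._local.state = _State()
        return self._local.state

    def get(self, name: str) -> Any:
        return getattr(self._get(), name)

    def set_tls(self, **kwargs: Any) -> Callable[[], None]:
        state = self._get()
        priors: Dict[str, Any] = {k: getattr(state, k) for k in kwargs}
        for k, v in kwargs.items():
            setattr(state, k, v)

        def revert() -> None:
            # restore the values that were in place before this call
            self.set_tls(**priors)

        return revert

    def enabled(self) -> bool:
        return self.get("compiler") is not None


local = _TLS()

# usage: install a compiler for this thread, then undo it
revert = local.set_tls(compiler=lambda gm: gm)
assert local.enabled()
revert()
assert not local.enabled()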

torch/_dynamo/config.py

Lines changed: 5 additions & 0 deletions
@@ -472,6 +472,11 @@ def default_debug_dir_root():
 # Overrides torch.compile() kwargs for Compiled Autograd:
 compiled_autograd_kwargs_override: Dict[str, Any] = {}

+# Compiled Autograd will attempt to automatically wrap C++ autograd functions found in the autograd graph,
+# and make them opaque to the compiler. This does not work when the C++ backward implementation involves
+# other dispatcher subsystems e.g. custom subclasses, autocast, vmap.
+compiled_autograd_opaque_cpp_node = False
+
 # Enables use of collectives *during* compilation to synchronize behavior
 # across ranks. Today, this is used solely to modify automatic_dynamic_shapes
 # behavior, making it so that we infer that if an input is dynamic by
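The new flag defaults to False. A hedged example of how this kind of Dynamo config is typically toggled (torch._dynamo.config.patch scopes the change and restores the prior value on exit):

import torch

# opt in for a limited region; the prior value is restored on exit
with torch._dynamo.config.patch(compiled_autograd_opaque_cpp_node=True):
    ...  # run compiled autograd with opaque C++ nodes here

# or set it process-wide
torch._dynamo.config.compiled_autograd_opaque_cpp_node = True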

torch/_dynamo/utils.py

Lines changed: 1 addition & 1 deletion
@@ -3051,7 +3051,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
         if node.op == "placeholder" and node.meta.get("steal_arg", False)
     ]

-    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
         # fast path, avoid pytree overhead
         # compiled autograd inputs are always a list of tensors, maybe followed by symints
         assert inputs_idx_to_clear == [0]

torch/_dynamo/variables/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -313,7 +313,7 @@ def create(
         user_hooks: VariableTracker,
         user_pre_hooks: VariableTracker,
     ):
-        if not compiled_autograd.compiled_autograd_enabled:
+        if not compiled_autograd.local.enabled():
            unimplemented("module-level backwards hooks require compiled autograd")

         def _in_graph_bw_hooks(bw_state: BackwardState):

torch/_dynamo/variables/misc.py

Lines changed: 1 addition & 1 deletion
@@ -929,7 +929,7 @@ def call_method(
         kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if name == "queue_callback":
-            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+            if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
                 assert (
                     tx.one_graph
                 ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"

torch/_dynamo/variables/tensor.py

Lines changed: 1 addition & 1 deletion
@@ -1007,7 +1007,7 @@ def _method_register_hook(self, name: str, hook: VariableTracker):
         tx = InstructionTranslator.current_tx()

         if not self.source:
-            if not compiled_autograd.compiled_autograd_enabled:
+            if not compiled_autograd.local.enabled():
                 # TODO(voz):
                 # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                 # python state.
