Commit 5a13282

xmfan authored and pytorchmergebot committed
[compiled autograd] tls access helpers (pytorch#138061)
Pull Request resolved: pytorch#138061
Approved by: https://github.com/yf225
ghstack dependencies: pytorch#137953, pytorch#137821
1 parent 49fa437 commit 5a13282

12 files changed (+20 −17 lines)
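
For context, this commit replaces direct reads of the compiled-autograd thread-local state (`compiled_autograd.local.get("in_compiled_autograd_region")` and `local.enabled()`) with module-level helpers `enabled()` and `in_compiled_autograd_region()`. The snippet below is a minimal, self-contained sketch of that pattern, assuming a simplified TLSWrapper whose internals and field initialization are illustrative only; it is not the actual torch._dynamo.compiled_autograd implementation.

# Minimal sketch of the TLS-access-helper pattern (simplified; the
# TLSWrapper internals and default field values here are assumptions).
import threading
from typing import Any


class TLSWrapper:
    """Per-thread compiled-autograd state hidden behind get/set."""

    def __init__(self) -> None:
        self._local = threading.local()

    def _get_tls(self):
        # Lazily create the fields this sketch assumes exist.
        if not hasattr(self._local, "compiler"):
            self._local.compiler = None  # callable or None
            self._local.in_compiled_autograd_region = False
        return self._local

    def get(self, name: str) -> Any:
        return getattr(self._get_tls(), name)

    def set(self, name: str, value: Any) -> None:
        setattr(self._get_tls(), name, value)


local = TLSWrapper()


# Module-level helpers: call sites no longer reach into `local` with
# string keys, mirroring the two functions added by this commit.
def enabled() -> bool:
    return local.get("compiler") is not None


def in_compiled_autograd_region() -> bool:
    return local.get("in_compiled_autograd_region")
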

test/distributed/_composable/fsdp/test_fully_shard_compile.py

Lines changed: 1 addition & 1 deletion
@@ -256,7 +256,7 @@ def _check_count(copy_count, resize_count):
                 f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}",  # noqa: B950
             )
 
-        if not torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+        if not torch._dynamo.compiled_autograd.in_compiled_autograd_region():
             _check_count(fwd_copy_count, fwd_resize_count)  # fwd graph
         else:
             _check_count(bwd_copy_count, bwd_resize_count)  # bwd graph

test/dynamo/test_activation_checkpointing.py

Lines changed: 1 addition & 3 deletions
@@ -86,9 +86,7 @@ def match_rng_op(node, op):
 
 
 def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
-    if not torch._dynamo.compiled_autograd.local.get(
-        "in_compiled_autograd_region"
-    ):  # fwd graph
+    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region():  # fwd graph
         return_node = list(graph.nodes)[-1]
         assert return_node.target == "output"
         for x in return_node.args[0]:

test/inductor/test_compiled_autograd.py

Lines changed: 1 addition & 1 deletion
@@ -2412,7 +2412,7 @@ def train(errors, model, x):
             try:
                 out = model(x)
                 with compiled_autograd.enable(compiler_fn):
-                    self.assertEqual(compiled_autograd.local.enabled(), True)
+                    self.assertEqual(compiled_autograd.enabled(), True)
                     self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1)
             except Exception as e:
                 print(f"Found error: {e}")

torch/_dynamo/compiled_autograd.py

Lines changed: 8 additions & 3 deletions
@@ -88,9 +88,6 @@ def revert():
 
         return revert
 
-    def enabled(self) -> bool:
-        return self.get("compiler") is not None
-
     def enter_ctx(self) -> Callable[[], None]:
         state = self._get_tls()
         state.next_ctx_id += 1
@@ -127,6 +124,14 @@ def exit():
 local = TLSWrapper()
 
 
+def enabled() -> bool:
+    return local.get("compiler") is not None
+
+
+def in_compiled_autograd_region() -> bool:
+    return local.get("in_compiled_autograd_region")
+
+
 def maybe_clone(x):
     if x is not None:
         return clone_preserve_strides(x)
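
For reference, a hedged sketch of how a call site reads after this refactor; `maybe_cache_graph` is an illustrative helper (not from this PR), modeled on the autograd_cache.py hunk further down.

# Illustrative call site (hypothetical helper name), showing the old TLS
# access next to the new module-level helper added by this commit.
import torch
from torch._dynamo import compiled_autograd


def maybe_cache_graph(gm: torch.fx.GraphModule) -> bool:
    # Before this commit a caller would reach into the thread-local wrapper:
    #   compiled_autograd.local.get("in_compiled_autograd_region")
    # After it, the dedicated helper hides the TLS key:
    if compiled_autograd.in_compiled_autograd_region():
        # Mirrors the autograd_cache.py hunk further down: graphs produced
        # inside a compiled autograd region bypass the AOTAutograd cache.
        return False
    return True
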

torch/_dynamo/utils.py

Lines changed: 1 addition & 1 deletion
@@ -3051,7 +3051,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
         if node.op == "placeholder" and node.meta.get("steal_arg", False)
     ]
 
-    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+    if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
         # fast path, avoid pytree overhead
         # compiled autograd inputs are always a list of tensors, maybe followed by symints
         assert inputs_idx_to_clear == [0]

torch/_dynamo/variables/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -313,7 +313,7 @@ def create(
         user_hooks: VariableTracker,
         user_pre_hooks: VariableTracker,
     ):
-        if not compiled_autograd.local.enabled():
+        if not compiled_autograd.enabled():
             unimplemented("module-level backwards hooks require compiled autograd")
 
         def _in_graph_bw_hooks(bw_state: BackwardState):

torch/_dynamo/variables/misc.py

Lines changed: 1 addition & 1 deletion
@@ -929,7 +929,7 @@ def call_method(
         kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if name == "queue_callback":
-            if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+            if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
                 assert (
                     tx.one_graph
                 ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"

torch/_dynamo/variables/tensor.py

Lines changed: 1 addition & 1 deletion
@@ -1007,7 +1007,7 @@ def _method_register_hook(self, name: str, hook: VariableTracker):
         tx = InstructionTranslator.current_tx()
 
         if not self.source:
-            if not compiled_autograd.local.enabled():
+            if not compiled_autograd.enabled():
                 # TODO(voz):
                 # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                 # python state.

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def check_cacheable(gm: torch.fx.GraphModule):
     Checks that the graph module only uses supported operators
     """
     nodes = gm.graph.nodes
-    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+    if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
         raise BypassAOTAutogradCache(
             "Cannot cache a graph with compiled autograd enabled"
         )

torch/_functorch/_aot_autograd/collect_metadata_analysis.py

Lines changed: 1 addition & 1 deletion
@@ -704,7 +704,7 @@ def view_avoid_dupes_with_primals(t):
         traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
         nonlocal static_input_indices
         static_input_indices = static_input_indices or []
-        if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+        if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
             passed_indices = set(static_input_indices)
             static_input_indices = [
                 i
