 # mypy: allow-untyped-defs
 import contextlib
 import functools
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+import threading
+from dataclasses import dataclass
+from logging import Logger
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import torch
 from torch._dynamo.external_utils import (
 verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")


-def snapshot_verbose_logging_enabled():
-    return torch._logging._internal.log_state.is_artifact_enabled(
-        "compiled_autograd_verbose"
-    )
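+# Per-thread state for compiled autograd: next_ctx_id counts the live
+# enable()/disable() context managers, in_compiled_autograd_region marks that
+# we are executing a compiled autograd graph, and compiler/vlogger hold the
+# active compiler factory and verbose logger (both None when disabled).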
+@dataclass
+class CompiledAutogradTLS:
+    next_ctx_id: int = 0
+    in_compiled_autograd_region: bool = False
+    compiler: Optional["AutogradCompilerInstance"] = None
+    vlogger: Optional[Logger] = None
+
+
+class TLSWrapper:
+    tls_key = "compiled_autograd_state"
+
+    def __init__(self):
+        self._local = threading.local()
+
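+    # The state is kept both in Python (threading.local) and stashed in the
+    # C++ TLS (torch._C._stash_obj_in_tls) so the C++ side can observe the
+    # same per-thread object; when both exist, the C++ copy takes precedence.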
+    def _get_tls(self) -> CompiledAutogradTLS:
+        if hasattr(self._local, self.tls_key):
+            # first look in python
+            state = getattr(self._local, self.tls_key)
+            if torch._C._is_key_in_tls(self.tls_key):
+                # then look in cpp, which wins if both are set
+                state = torch._C._get_obj_in_tls(self.tls_key)
+        else:
+            # init new thread created outside of autograd
+            # TODO: what if context manager wrapped outside of thread?
+            setattr(self._local, self.tls_key, CompiledAutogradTLS())
+            state = getattr(self._local, self.tls_key)
+            torch._C._stash_obj_in_tls(self.tls_key, state)
+        return state
+
+    # queries on the object stored in TLS
+    def get(self, name):
+        return getattr(self._get_tls(), name)
+
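+    # Update fields on the TLS object and return a zero-argument callable
+    # that restores the prior values; notify_autograd_engine() lets the C++
+    # side pick up the change.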
+    def set_tls(self, **kwargs) -> Callable[[], None]:
+        state = self._get_tls()
+        priors: Dict[str, Any] = {}
+        for k, v in kwargs.items():
+            priors[k] = getattr(state, k)
+            setattr(state, k, v)
+
+        torch._C._dynamo.compiled_autograd.notify_autograd_engine()
+
+        def revert():
+            self.set_tls(**priors)
+
+        return revert
+
+    def enabled(self) -> bool:
+        return self.get("compiler") is not None
+
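+    # Count context-manager entries; the returned exit callback enforces LIFO
+    # nesting and same-thread exit via the assertions below.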
+    def enter_ctx(self) -> Callable[[], None]:
+        state = self._get_tls()
+        state.next_ctx_id += 1
+        ctx_id = state.next_ctx_id
+
+        def exit():
+            assert (
+                state is self._get_tls()
+            ), "Runtime must begin and end on the same thread"
+            assert state.next_ctx_id == ctx_id, (
+                "Error nesting compiled autograd context managers: "
+                "inner context managers must have shorter lifetimes than the outer context manager"
+            )
+            state.next_ctx_id -= 1
+
+        return exit
+
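+    # Flag entry into a compiled autograd runtime region; nesting is rejected.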
+    def enter_compiled_region(self) -> Callable[[], None]:
+        state = self._get_tls()
+        prior = state.in_compiled_autograd_region
+        state.in_compiled_autograd_region = True
+        assert prior is False, "Nested compiled autograd regions are not supported"
+
+        def exit():
+            assert (
+                state is self._get_tls()
+            ), "Runtime must begin and end on the same thread"
+            assert state.in_compiled_autograd_region is True
+            state.in_compiled_autograd_region = prior
+
+        return exit


-def snapshot_cudagraph_enabled():
-    return torch._inductor.config.triton.cudagraphs
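+# Module-wide accessor: every thread reads and writes its own CompiledAutogradTLS.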
+local = TLSWrapper()


 def maybe_clone(x):
@@ -307,7 +386,7 @@ def end_capture(self, outputs):
         self.rename_aot_dispatcher_nodes()
         self.reorder_accumulate_grad_nodes()
         runtime_inputs_to_move: List[int] = []
-        if snapshot_cudagraph_enabled():
+        if torch._inductor.config.triton.cudagraphs:
             runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)

         graph = GraphModule(
@@ -329,16 +408,15 @@ def end_capture(self, outputs):
         )

         def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
-            global in_compiled_autograd_region
+            exit_compiled_region = local.enter_compiled_region()
             try:
-                in_compiled_autograd_region = True
                 for i in runtime_inputs_to_move:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

                 with disable():
                     return compiled_fn(inputs, sizes, scalars, hooks)
             finally:
-                in_compiled_autograd_region = False
+                exit_compiled_region()

         return runtime_wrapper, self.compiler_fn(graph)

@@ -510,15 +588,9 @@ def set_node_origin(
         set_stack_trace(new_stack_trace)


-# state of the autograd engine dispatch, kept in sync by enable/disable context managers
-compiled_autograd_enabled = False
-
 # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager"
 compiled_autograd_enabled_force_eager = False

-# global flag to check if we are processing graphs produced from a compiled autograd graph
-in_compiled_autograd_region = False
-

 @contextlib.contextmanager
 def enable(compiler_fn):
@@ -538,39 +610,48 @@ def enable(compiler_fn):
     # we need to lazily import it, because of circular dependencies
     import torch._inductor.cudagraph_trees

-    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
-        functools.partial(AutogradCompilerInstance, compiler_fn)
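+    # Install the compiler (and the verbose logger, if the artifact log is
+    # enabled) for this thread; both are reverted in the finally block below.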
+    exit_ctx = local.enter_ctx()
+    revert_tls = local.set_tls(
+        compiler=functools.partial(AutogradCompilerInstance, compiler_fn),
+        vlogger=verbose_log
+        if torch._logging._internal.log_state.is_artifact_enabled(
+            "compiled_autograd_verbose"
+        )
+        else None,
     )
-    if snapshot_verbose_logging_enabled():
-        torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = True
     try:
         with torch.autograd.set_multithreading_enabled(False):
             yield
     finally:
-        if not prior:
-            compiled_autograd_enabled = False
-        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+        revert_tls()
+        exit_ctx()


 @contextlib.contextmanager
 def disable():
-    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = False
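+    # Clear the compiler for this thread so that autograd calls made inside
+    # the region are not traced by compiled autograd.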
+    exit_ctx = local.enter_ctx()
+    revert_tls = local.set_tls(
+        compiler=None,
+        vlogger=None,
+    )
     try:
         yield
     finally:
-        if prior:
-            compiled_autograd_enabled = True
-        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+        revert_tls()
+        exit_ctx()


 # return to starting state of a new process
 def reset() -> None:
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = False
-    assert not in_compiled_autograd_region
-    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
-    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
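+    # A fresh process must have no open contexts and must not be executing a
+    # compiled autograd graph.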
+    assert local.get("next_ctx_id") == 0
+    assert local.get("in_compiled_autograd_region") is False
+    local.set_tls(
+        compiler=None,
+        vlogger=None,
+    )