@@ -58,215 +58,6 @@ class ExportDynamoConfig:
     allow_rnn: bool = True


-# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
-# is called multiple times.
-@lru_cache
-def capture_pre_autograd_graph_warning():
-    from torch._inductor import config
-
-    log.warning("+============================+")
-    log.warning("|     !!! WARNING !!!        |")
-    log.warning("+============================+")
-    log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
-    log.warning("Please switch to use torch.export.export_for_training instead.")
-    if config.is_fbcode():
-        log.warning("For unittest, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.")  # noqa: B950
-
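The warning removed above points callers at torch.export.export_for_training. As a rough, hedged sketch of that migration (the module M and the example inputs below are hypothetical placeholders, not part of this diff), the old call can typically be replaced like so:

import torch

# Hypothetical toy module and example inputs, only to illustrate the call shape.
class M(torch.nn.Module):
    def forward(self, x):
        return x.relu()

m = M()
example_args = (torch.randn(4, 8),)

# Old (deprecated) entry point:
#   gm = capture_pre_autograd_graph(m, example_args)
# Suggested replacement, mirroring the fallback this diff's code already uses:
gm = torch.export.export_for_training(m, example_args, strict=True).module()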
-@lru_cache
-def print_export_warning():
-    log.warning("Using torch.export.export_for_training(...,strict=True)")
-
-def gm_using_training_ir(graph_module: torch.fx.GraphModule) -> bool:
-    """
-    Returns true if the graph module is detected to use training IR.
-
-    This function checks for two specific conditions within the nodes of the graph module:
-    1. The presence of the `torch.ops.aten.batch_norm.default` operation which indicates the use of training IR.
-    2. The presence of deprecated IR tags on node meta or batch norm ops produced by the deprecated IR.
-
-    The function raises a RuntimeError if both conditions are met, indicating a conflict in the IR.
-    """
-    # TODO: clean up this code after training IR migration.
-    # T199018392
-    has_training_ir_batch_norm = False
-    has_deprecated_ir_tag = getattr(graph_module, "capture_pre_autograd_graph_tag", False)
-    for node in graph_module.graph.nodes:
-        if node.op == "call_function":
-            if node.target == torch.ops.aten.batch_norm.default:
-                has_training_ir_batch_norm = True
-            if node.meta.get("capture_pre_autograd_graph_tag", False):
-                has_deprecated_ir_tag = True
-            if node.target in [
-                torch.ops.aten._native_batch_norm_legit.default,
-                torch.ops.aten.cudnn_batch_norm.default,
-                torch.ops.aten.miopen_batch_norm.default,
-            ]:
-                has_deprecated_ir_tag = True
-
-    if has_deprecated_ir_tag and has_training_ir_batch_norm:
-        raise RuntimeError("Conflicting IR detected.")
-    return has_training_ir_batch_norm or not has_deprecated_ir_tag
-
-@compatibility(is_backward_compatible=False)
-def capture_pre_autograd_graph(
-    f: torch.nn.Module,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
-) -> torch.nn.Module:
-    """
-    A helper function that is intended to trace a module before any pre-autograd
-    decomposition is run. The produced module will be "non-functional" and
-    composed of aten operators. Later this API will be deleted in favor of more general
-    torch.export API.
-
-    Args:
-      f: nn.Module to be traced
-
-      args: example positional inputs.
-
-      kwargs: optional example keyword inputs.
-
-      dynamic_shapes: Should either be:
-         1) a dict from argument names of ``f`` to their dynamic shape specifications,
-         2) a tuple that specifies dynamic shape specifications for each input in original order.
-         If you are specifying dynamism on keyword args, you will need to pass them in the order that
-         is defined in the original function signature.
-
-         The dynamic shape of a tensor argument can be specified as either
-         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
-         not required to include static dimension indices in this dict, but when they are,
-         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
-         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
-         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
-         recursively specified by using mappings or sequences of contained specifications.
-
-    Returns:
-        An nn.Module containing the traced method.
-
-    """
-    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
-    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
-    from torch._export.non_strict_utils import make_constraints
-    from torch._subclasses.functional_tensor import FunctionalTensor
-    from torch.export._unlift import _create_stateful_graph_module
-    from torch.export.dynamic_shapes import _combine_args
-
-    capture_pre_autograd_graph_warning()
-
-    if sys.platform == "win32":
-        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")
-
-    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
-
-    if kwargs is None:
-        kwargs = {}
-
-    if capture_pre_autograd_graph_using_training_ir():
-        print_export_warning()
-        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
-    else:
-        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
-
-        # Do not decompose dropout for exported models, because in eval mode the dropout
-        # op disappears from the graph, which makes it difficult to switch to train mode.
-        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
-
-        # We force create native_batch_norm because the below materialization logic
-        # only applies to CIA ops.
-        maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
-
-        _materialize_cpp_cia_ops()
-
-        for op in torch.ops.aten:
-            op_obj = getattr(torch.ops.aten, op)
-            for overload in op_obj.overloads():
-                op_overload = getattr(op_obj, overload)
-                if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
-                    maybe_aliasing_or_mutating_ops.append(op_overload)
-
-        decomp_table = {
-            op: op.decompose
-            for op in maybe_aliasing_or_mutating_ops
-            if op != torch.ops.aten.dropout.default
-        }
-        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
-            m = torch._dynamo.export(
-                f,
-                dynamic_shapes=dynamic_shapes,
-                assume_static_by_default=True,
-                tracing_mode="symbolic",
-                decomposition_table=decomp_table,
-                pre_dispatch=True,
-                aten_graph=True,
-                _log_export_usage=False,
-            )(
-                *args,
-                **kwargs,
-            )[0]
-
-        _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
-
-        m.meta["inline_constraints"] = {
-            k: v
-            for k, v in fake_mode.shape_env.var_to_range.items()
-            if re.match(r"^[if]\d+$", str(k))
-        }
-
-        if isinstance(f, torch.nn.Module):
-            from torch.export._trace import _restore_state_dict
-            _restore_state_dict(f, m)
-
-        combined_args = _combine_args(f, args, kwargs)
-        range_constraints = make_constraints(
-            fake_mode,
-            m,
-            combined_args,
-            dynamic_shapes,
-            0,
-        )
-
-        module = _create_stateful_graph_module(
-            m,
-            range_constraints=range_constraints,
-        )
-
-        setattr(module, "capture_pre_autograd_graph_tag", True)  # noqa: B010
-        for node in module.graph.nodes:
-            node.meta["capture_pre_autograd_graph_tag"] = True
-
-    error_message = \
-        """
-        Calling train() or eval() is not supported for exported models.
-        Alternatively, you may override these methods to do custom user behavior as follows:
-
-        def _my_train(self, mode: bool = True):
-            ...
-
-        def _my_eval(self):
-            ...
-
-        model.train = types.MethodType(_my_train, model)
-        model.eval = types.MethodType(_my_eval, model)
-        """
-
-    def _train(self, mode: bool = True):
-        raise NotImplementedError(error_message)
-
-    def _eval(self, mode: bool = True):
-        raise NotImplementedError(error_message)
-
-    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
-    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
-
-    # Remove Proxy because they cannot be deepcopied or pickled.
-    if hasattr(module, "_buffers"):
-        torch._export.utils.remove_proxy_from_state_dict(
-            module._buffers, in_place=True
-        )
-    return module
-
-
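The docstring of the removed capture_pre_autograd_graph describes the dynamic_shapes spec only in prose. A small sketch of the dict-based form it describes follows; the module, argument name "x", and sizes are made-up placeholders, and export_for_training is used here only because it is the replacement this diff recommends and accepts the same dynamic_shapes format.

import torch
from torch.export import Dim

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

m = M()
example_args = (torch.randn(4, 8),)

# Mark only dimension 0 of ``x`` as dynamic; static dims may be omitted or mapped to None.
batch = Dim("batch")
dynamic_shapes = {"x": {0: batch}}

# The deprecated API took this spec as ``dynamic_shapes=dynamic_shapes``; the same
# spec can be passed to the suggested replacement.
gm = torch.export.export_for_training(
    m, example_args, dynamic_shapes=dynamic_shapes, strict=True
).module()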
 # We only want to print this once to avoid flooding logs in workflows where aot_compile_warning
 # is called multiple times.
 @lru_cache