
Commit d8ea4ce

yushangdi authored and pytorchmergebot committed
[reland] Kill capture_pre_autograd_graph API (pytorch#143426)
Summary:
Delete the following APIs:
- capture_pre_autograd_graph()
- capture_pre_autograd_graph_using_training_ir()
- gm_using_training_ir()

Update the XLA pin to include pytorch/xla#8398.

There are no remaining call sites of `capture_pre_autograd_graph`, except:
1) two test cases in coremltools, behind a version guard (removal PR: apple/coremltools#2400)
2) a few call sites behind a version guard (< 2.5.0)

Test Plan: CI

Differential Revision: D67354440

Pull Request resolved: pytorch#143426
Approved by: https://github.com/gmagogsfm
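As context for the migration this removal assumes, here is a minimal sketch of moving a call site to the replacement API. The module and example inputs below are illustrative, not taken from this PR; the replacement call mirrors the training-IR fallback path the deleted API already used internally.

import torch


class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


m = M()
example_args = (torch.randn(2, 3),)

# Before (removed by this commit):
#   gm = torch._export.capture_pre_autograd_graph(m, example_args)

# After: export_for_training returns an ExportedProgram; .module() yields an
# nn.Module backed by a pre-dispatch ATen graph.
gm = torch.export.export_for_training(m, example_args, strict=True).module()

kwargs and dynamic_shapes pass through the same way as before, e.g. torch.export.export_for_training(m, args, kwargs, dynamic_shapes=...).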
1 parent eb67dd3 commit d8ea4ce

File tree: 8 files changed, 6 additions & 250 deletions


.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-73f54ba5bd7fb83d7ba81fe6f5e05fb6ee815d6f
+b2b890e962f5fb6f481e5da2eb4a43bb990d0f1b

torch/_export/__init__.py

Lines changed: 0 additions & 209 deletions
@@ -58,215 +58,6 @@ class ExportDynamoConfig:
     allow_rnn: bool = True


-# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
-# is called multiple times.
-@lru_cache
-def capture_pre_autograd_graph_warning():
-    from torch._inductor import config
-
-    log.warning("+============================+")
-    log.warning("|     !!!   WARNING   !!!    |")
-    log.warning("+============================+")
-    log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
-    log.warning("Please switch to use torch.export.export_for_training instead.")
-    if config.is_fbcode():
-        log.warning("For unittest, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.")  # noqa: B950
-
-@lru_cache
-def print_export_warning():
-    log.warning("Using torch.export.export_for_training(...,strict=True)")
-
-def gm_using_training_ir(graph_module: torch.fx.GraphModule) -> bool:
-    """
-    Returns true if the graph module is detected to use training IR.
-
-    This function checks for two specific conditions within the nodes of the graph module:
-    1. The presence of the `torch.ops.aten.batch_norm.default` operation which indicates the use of training IR.
-    2. The presence of deprecated IR tags on node meta or batch norm ops produced by the deprecated IR.
-
-    The function raises a RuntimeError if both conditions are met, indicating a conflict in the IR.
-    """
-    # TODO: clean up this code after training IR migration.
-    # T199018392
-    has_training_ir_batch_norm = False
-    has_deprecated_ir_tag = getattr(graph_module, "capture_pre_autograd_graph_tag", False)
-    for node in graph_module.graph.nodes:
-        if node.op == "call_function":
-            if node.target == torch.ops.aten.batch_norm.default:
-                has_training_ir_batch_norm = True
-            if node.meta.get("capture_pre_autograd_graph_tag", False):
-                has_deprecated_ir_tag = True
-            if node.target in [
-                torch.ops.aten._native_batch_norm_legit.default,
-                torch.ops.aten.cudnn_batch_norm.default,
-                torch.ops.aten.miopen_batch_norm.default,
-            ]:
-                has_deprecated_ir_tag = True
-
-    if has_deprecated_ir_tag and has_training_ir_batch_norm:
-        raise RuntimeError("Conflicting IR detected.")
-    return has_training_ir_batch_norm or not has_deprecated_ir_tag
-
-@compatibility(is_backward_compatible=False)
-def capture_pre_autograd_graph(
-    f: torch.nn.Module,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
-) -> torch.nn.Module:
-    """
-    A helper function that is intended to trace a module before any pre-autograd
-    decomposition is run. The produced module will be "non-functional" and
-    composed of aten operators. Later this API will be deleted in favor of more general
-    torch.export API.
-
-    Args:
-      f: nn.Module to be traced
-
-      args: example positional inputs.
-
-      kwargs: optional example keyword inputs.
-
-      dynamic_shapes: Should either be:
-         1) a dict from argument names of ``f`` to their dynamic shape specifications,
-         2) a tuple that specifies dynamic shape specifications for each input in original order.
-         If you are specifying dynamism on keyword args, you will need to pass them in the order that
-         is defined in the original function signature.
-
-         The dynamic shape of a tensor argument can be specified as either
-         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
-         not required to include static dimension indices in this dict, but when they are,
-         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
-         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
-         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
-         recursively specified by using mappings or sequences of contained specifications.
-
-    Returns:
-        An nn.Module containing the traced method.
-
-    """
-    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
-    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
-    from torch._export.non_strict_utils import make_constraints
-    from torch._subclasses.functional_tensor import FunctionalTensor
-    from torch.export._unlift import _create_stateful_graph_module
-    from torch.export.dynamic_shapes import _combine_args
-
-    capture_pre_autograd_graph_warning()
-
-    if sys.platform == "win32":
-        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")
-
-    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
-
-    if kwargs is None:
-        kwargs = {}
-
-    if capture_pre_autograd_graph_using_training_ir():
-        print_export_warning()
-        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
-    else:
-        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
-
-        # Do not decompose dropout for exported models, because in eval mode the dropout
-        # op disappears from the graph, which makes it difficult to switch to train mode.
-        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
-
-        # We force create native_batch_norm because the below materialization logic
-        # only applies to CIA ops.
-        maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
-
-        _materialize_cpp_cia_ops()
-
-        for op in torch.ops.aten:
-            op_obj = getattr(torch.ops.aten, op)
-            for overload in op_obj.overloads():
-                op_overload = getattr(op_obj, overload)
-                if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
-                    maybe_aliasing_or_mutating_ops.append(op_overload)
-
-        decomp_table = {
-            op: op.decompose
-            for op in maybe_aliasing_or_mutating_ops
-            if op != torch.ops.aten.dropout.default
-        }
-        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
-            m = torch._dynamo.export(
-                f,
-                dynamic_shapes=dynamic_shapes,
-                assume_static_by_default=True,
-                tracing_mode="symbolic",
-                decomposition_table=decomp_table,
-                pre_dispatch=True,
-                aten_graph=True,
-                _log_export_usage=False,
-            )(
-                *args,
-                **kwargs,
-            )[0]
-
-            _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
-
-            m.meta["inline_constraints"] = {
-                k: v
-                for k, v in fake_mode.shape_env.var_to_range.items()
-                if re.match(r"^[if]\d+$", str(k))
-            }
-
-            if isinstance(f, torch.nn.Module):
-                from torch.export._trace import _restore_state_dict
-                _restore_state_dict(f, m)
-
-            combined_args = _combine_args(f, args, kwargs)
-            range_constraints = make_constraints(
-                fake_mode,
-                m,
-                combined_args,
-                dynamic_shapes,
-                0,
-            )
-
-            module = _create_stateful_graph_module(
-                m,
-                range_constraints=range_constraints,
-            )
-
-        setattr(module, "capture_pre_autograd_graph_tag", True)  # noqa: B010
-        for node in module.graph.nodes:
-            node.meta["capture_pre_autograd_graph_tag"] = True
-
-    error_message = \
-        """
-        Calling train() or eval() is not supported for exported models.
-        Alternatively, you may override these methods to do custom user behavior as follows:
-
-            def _my_train(self, mode: bool = True):
-                ...
-
-            def _my_eval(self):
-                ...
-
-            model.train = types.MethodType(_my_train, model)
-            model.eval = types.MethodType(_my_eval, model)
-        """
-
-    def _train(self, mode: bool = True):
-        raise NotImplementedError(error_message)
-
-    def _eval(self, mode: bool = True):
-        raise NotImplementedError(error_message)
-
-    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
-    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
-
-    # Remove Proxy because they cannot be deepcopied or pickled.
-    if hasattr(module, "_buffers"):
-        torch._export.utils.remove_proxy_from_state_dict(
-            module._buffers, in_place=True
-        )
-    return module
-
-
 # We only want to print this once to avoid flooding logs in workflows where aot_compile_warning
 # is called multiple times.
 @lru_cache

torch/_utils_internal.py

Lines changed: 0 additions & 4 deletions
@@ -167,10 +167,6 @@ def log_torch_jit_trace_exportability(
     return


-def capture_pre_autograd_graph_using_training_ir() -> bool:
-    return False
-
-
 def justknobs_check(name: str, default: bool = True) -> bool:
     """
     This function can be used to killswitch functionality in FB prod,

torch/ao/quantization/pt2e/export_utils.py

Lines changed: 0 additions & 14 deletions
@@ -55,10 +55,6 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
     m.graph.eliminate_dead_code()
     m.recompile()

-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     for inplace in [False, True]:

         def dropout_train(x):
@@ -72,23 +68,19 @@ def dropout_eval(x):
            match_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_train),
                example_inputs,
-                using_training_ir=using_training_ir,
            )
            replacement_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_eval),
                example_inputs,
-                using_training_ir=using_training_ir,
            )
        else:
            match_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_eval),
                example_inputs,
-                using_training_ir=using_training_ir,
            )
            replacement_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_train),
                example_inputs,
-                using_training_ir=using_training_ir,
            )

     from torch.fx.subgraph_rewriter import replace_pattern_with_filters
@@ -122,10 +114,6 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
     m.graph.eliminate_dead_code()
     m.recompile()

-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     def bn_train(
         x: torch.Tensor,
         bn_weight: torch.Tensor,
@@ -162,13 +150,11 @@ def bn_eval(
         _WrapperModule(bn_train),
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     bn_eval_aten = _get_aten_graph_module_for_pattern(
         _WrapperModule(bn_eval),
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )

     if train_to_eval:

torch/ao/quantization/pt2e/qat_utils.py

Lines changed: 0 additions & 12 deletions
@@ -667,16 +667,11 @@ def _fuse_conv_bn_qat_helper(
     m.graph.eliminate_dead_code()
     m.recompile()

-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)
-
     conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
     match_pattern = _get_aten_graph_module_for_pattern(
         conv_bn_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )

     # Step (1): Replace patterns with conv bias
@@ -690,7 +685,6 @@ def _fuse_conv_bn_qat_helper(
         qat_conv_bn_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     replacements_with_conv_bias = replace_pattern_with_filters(
         m,
@@ -708,7 +702,6 @@ def _fuse_conv_bn_qat_helper(
         qat_conv_bn_pattern_no_conv_bias,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
     )
     replacements_no_conv_bias = replace_pattern_with_filters(
         m,
@@ -922,9 +915,6 @@ def _fold_conv_bn_qat_helper(
     """
     Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
     """
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(m)

     m.graph.eliminate_dead_code()
     m.recompile()
@@ -958,7 +948,6 @@ def _fold_conv_bn_qat_helper(
         match_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
         **kwargs,
     )
     replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
@@ -968,7 +957,6 @@ def _fold_conv_bn_qat_helper(
         replacement_pattern,
         example_inputs,
         is_cuda,
-        using_training_ir=using_training_ir,
         **kwargs,
     )
     replacements.extend(

torch/ao/quantization/pt2e/representation/rewrite.py

Lines changed: 2 additions & 5 deletions
@@ -797,19 +797,16 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
     ]

     remove_tensor_overload_for_qdq_ops(model)
-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(model)

     for rewrite_info in _REWRITE_INFO_LIST:
         example_inputs = rewrite_info.example_inputs
         pattern = rewrite_info.pattern
         replacement = rewrite_info.replacement
         pattern_post_trans = rewrite_info.pattern_post_trans
         replacement_post_trans = rewrite_info.replacement_post_trans
-        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, using_training_ir=using_training_ir)  # type: ignore[arg-type, assignment]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
-        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs, using_training_ir=using_training_ir)  # type: ignore[arg-type, assignment]
+        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
         if pattern_post_trans:
             pattern = pattern_post_trans(pattern)

torch/ao/quantization/pt2e/utils.py

Lines changed: 2 additions & 0 deletions
@@ -351,6 +351,8 @@ def _get_aten_graph_module_for_pattern(
             [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
         )

+    # T199018392
+    # TODO: remove the using_training_ir flag from function
     if using_training_ir:
         aten_pattern = torch.export.export_for_training(
             pattern,  # type: ignore[arg-type]

torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py

Lines changed: 1 addition & 5 deletions
@@ -530,10 +530,6 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
     gm.graph.eliminate_dead_code()
     gm.recompile()

-    from torch._export import gm_using_training_ir
-
-    using_training_ir = gm_using_training_ir(gm)
-
     matches = []
     if is_conv_transpose:
         combinations = [
@@ -556,7 +552,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
     # Match against all conv dimensions and cuda variants
     for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:  # type: ignore[misc]
         pattern = get_pattern(conv_fn, relu_is_inplace)  # type: ignore[has-type]
-        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda, using_training_ir=using_training_ir)  # type: ignore[has-type]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)  # type: ignore[has-type]
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
