Skip to content

Commit 3685c63

Browse files
ppanchalia authored and pytorchmergebot committed
[pytorch] Plumb compile context from dynamo.export to aot_compile (pytorch#138793)
Summary: tlparse shows unknown for certain items when _export.aot_compile() passes the graph obtained from dynamo.export() to inductor.aot_compile(), we also do not have access to the dynamo trace in the GraphModule exported by dynamo. This change plumbs through the compile_context into aot_compile as a part of GraphModule.meta without a major change to APIs within dynamo. Addresses issue: pytorch#123759 Test Plan: ``` buck2 test mode/opt //caffe2/test/dynamo:test_dynamo Buck UI: https://www.internalfb.com/buck2/ad64c267-65be-47cf-a94f-e4b26e6e030b Test UI: https://www.internalfb.com/intern/testinfra/testrun/9288674286334710 Network: Up: 83KiB Down: 314KiB (reSessionID-1dad223b-c91d-4718-97a4-bb2c81e480f0) Jobs completed: 10750. Time elapsed: 19:18.5s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3) Tests finished: Pass 5365. Fail 2. Fatal 0. Skip 4. Build failure 0 buck2 test mode/opt //caffe2/test/dynamo:test_dynamo_fb Buck UI: https://www.internalfb.com/buck2/179a60bb-34e1-43b3-97ad-91af8a93ab01 Test UI: https://www.internalfb.com/intern/testinfra/testrun/2533275046340687 Network: Up: 201KiB Down: 1.8GiB (reSessionID-36f33983-6d78-4ec9-aa1b-34cee80dcb4f) Jobs completed: 17. Time elapsed: 42.9s. Cache hits: 0%. Commands: 1 (cached: 0, remote: 0, local: 1) Tests finished: Pass 6. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxZGXf6/index.html Repor fixed: pytorch#123759 Differential Revision: D64863946 Pull Request resolved: pytorch#138793 Approved by: https://github.com/ezyang
1 parent 91ded05 commit 3685c63

File tree

5 files changed

+40
-20
lines changed

5 files changed

+40
-20
lines changed

torch/_dynamo/convert_frame.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -543,23 +543,24 @@ def __call__(
543543
info = f"{code.co_name} {code.co_filename}:{code.co_firstlineno}"
544544
dynamo_tls.traced_frame_infos.append(info)
545545

546-
return _compile(
547-
frame.f_code,
548-
frame.f_globals,
549-
frame.f_locals,
550-
frame.f_builtins,
551-
self._torchdynamo_orig_callable,
552-
self._one_graph,
553-
self._export,
554-
self._export_constraints,
555-
hooks,
556-
cache_entry,
557-
cache_size,
558-
frame,
559-
frame_state=frame_state,
560-
compile_id=compile_id,
561-
skip=skip + 1,
562-
)
546+
with compile_context(CompileContext(compile_id)):
547+
return _compile(
548+
frame.f_code,
549+
frame.f_globals,
550+
frame.f_locals,
551+
frame.f_builtins,
552+
self._torchdynamo_orig_callable,
553+
self._one_graph,
554+
self._export,
555+
self._export_constraints,
556+
hooks,
557+
cache_entry,
558+
cache_size,
559+
frame,
560+
frame_state=frame_state,
561+
compile_id=compile_id,
562+
skip=skip + 1,
563+
)
563564

564565

565566
def convert_frame_assert(

torch/_dynamo/eval_frame.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,8 @@ def transform(self):
10711071
result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
10721072
"dynamo_flat_name_to_original_fqn"
10731073
]
1074+
if "dynamo_compile_id" in self.module.meta:
1075+
result_gm.meta["dynamo_compile_id"] = self.module.meta["dynamo_compile_id"]
10741076
return result_gm
10751077

10761078

torch/_dynamo/output_graph.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@
2323
import torch.nn
2424
import torch.utils._pytree as pytree
2525
from torch import fx
26-
from torch._guards import GlobalContextCheckpointState, Source, TracingContext
26+
from torch._guards import (
27+
CompileContext,
28+
CompileId,
29+
GlobalContextCheckpointState,
30+
Source,
31+
TracingContext,
32+
)
2733
from torch._utils_internal import signpost_event
2834
from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
2935
from torch.fx.experimental._backward_state import BackwardState
@@ -313,6 +319,9 @@ def __init__(
313319
export=self.export,
314320
)
315321
self.tracing_context: TracingContext = TracingContext(fake_mode)
322+
self.dynamo_compile_id: Optional[
323+
CompileId
324+
] = CompileContext.current_compile_id()
316325
self.init_ambient_guards()
317326

318327
# Map each tensor id to a list of sources. This is necessary because
@@ -1368,6 +1377,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
13681377
gm.meta[
13691378
"dynamo_flat_name_to_original_fqn"
13701379
] = self.dynamo_flat_name_to_original_fqn.copy()
1380+
gm.meta["dynamo_compile_id"] = self.dynamo_compile_id
13711381

13721382
graph_code_log.debug(
13731383
"%s",

torch/_export/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch.utils._pytree as pytree
2525

2626
from torch._dispatch.python import enable_python_dispatcher
27+
from torch._guards import compile_context
2728
from torch._utils_internal import log_export_usage
2829
from torch.export._tree_utils import reorder_kwargs
2930
from torch.export.graph_signature import (
@@ -40,7 +41,6 @@
4041
from torch.fx import traceback as fx_traceback
4142
from torch.fx._compatibility import compatibility
4243
from torch.fx.experimental.proxy_tensor import make_fx
43-
from torch._subclasses.fake_tensor import unset_fake_temporarily
4444
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
4545

4646
from .wrappers import _wrap_submodules

torch/_inductor/compile_fx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,11 @@ def compile_fx_aot(
12281228
}
12291229

12301230
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
1231-
with V.set_aot_compilation(True):
1231+
saved_compile_id = model_.meta.get("dynamo_compile_id", None)
1232+
saved_compile_context = torch._guards.CompileContext(saved_compile_id)
1233+
with V.set_aot_compilation(True), torch._guards.compile_context(
1234+
saved_compile_context
1235+
):
12321236
compiled_lib_path = compile_fx(
12331237
model_,
12341238
example_inputs_,
@@ -1665,6 +1669,9 @@ def bw_compiler(
16651669
"dynamo_flat_name_to_original_fqn"
16661670
]
16671671

1672+
if "dynamo_compile_id" in model_.meta:
1673+
unlifted_gm.meta["dynamo_compile_id"] = model_.meta["dynamo_compile_id"]
1674+
16681675
# Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515)
16691676
# In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into
16701677
# _sfdp_init() to register patterns.

0 commit comments

Comments
 (0)