
Commit c1f51cf

ydwu4 authored and pytorchmergebot committed
[map] defer importing AOTConfig and create_joint dependency (pytorch#151479)
Summary: We reverted D72896450 due to a weird error in a seemingly unrelated test: buck2 run apf/data/tests:preproc_state_serializer_test -- --filter-text "test_load_artifact". Some investigation showed that moving the import of AOTConfig and create_joint inside create_fw_bw_graph delays importing the modules they transitively import from test-construction time to test-run time. The os.path.exists mock then gets called multiple times, because inspect.getsource is invoked in several places in torch. Specifically, we set a breakpoint at the side_effect of the mocked os.path.exists: P1787425831 shows the importing stack trace before the change, and P1787431638 shows the stack trace after it. The notable difference is that in the second paste, os.path.exists is triggered when triton calls inspect.getsourcelines while OnDiskPreprocStateSerializer is being constructed, and the mock records that call.

Looking at the test, what it actually wants to exercise is the deserialize step, so we reset_mock before that step to avoid asserting on calls that happen at import time.

Test Plan: buck2 run apf/data/tests:preproc_state_serializer_test -- --filter-text "test_load_artifact" and the existing tests for map.

Differential Revision: D73138415

Pull Request resolved: pytorch#151479

Approved by: https://github.com/angelayi, https://github.com/zou3519
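The reset_mock fix described above can be sketched in isolation. This is a minimal illustration of the pattern, not the actual test: load_artifact and the paths are hypothetical stand-ins for OnDiskPreprocStateSerializer's deserialize step, and the stray probe call stands in for the import-time os.path.exists calls triggered via inspect.getsource.

```python
import os
from unittest import mock

def load_artifact(path):
    # Hypothetical step under test: probes the path before loading.
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    return "artifact"

with mock.patch("os.path.exists", return_value=True) as mocked_exists:
    # Stand-in for import-time activity (e.g. inspect.getsource probing
    # source files) that also hits the patched os.path.exists.
    os.path.exists("/some/unrelated/probe")

    # Reset so the assertion below only sees calls made by the step
    # under test, not whatever happened during imports.
    mocked_exists.reset_mock()

    assert load_artifact("/tmp/artifact.bin") == "artifact"
    mocked_exists.assert_called_once_with("/tmp/artifact.bin")
```

Without the reset_mock, the earlier probe call would make assert_called_once_with fail even though the step under test behaved correctly.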
1 parent 99ae7d4 commit c1f51cf

File tree

1 file changed: +12 −11 lines changed


torch/_higher_order_ops/map.py

Lines changed: 12 additions & 11 deletions
@@ -3,7 +3,6 @@
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._dispatch.python import suspend_functionalization
-from torch._functorch.aot_autograd import AOTConfig, create_joint
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_alias,
     _has_potential_branch_input_mutation,
@@ -54,16 +53,6 @@ def __call__(self, *args, **kwargs):
 
 map_impl = MapImpl()
 
-dummy_aot_config = AOTConfig(
-    fw_compiler=None,  # type: ignore[arg-type]
-    bw_compiler=None,  # type: ignore[arg-type]
-    partition_fn=None,  # type: ignore[arg-type]
-    decompositions={},
-    num_params_buffers=0,
-    aot_id=0,
-    keep_inference_input_mutations=False,
-)
-
 
 def create_fw_bw_graph(f, num_mapped_args, *args):
     mapped_xs = args[:num_mapped_args]
@@ -96,6 +85,18 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
 
     fw_graph = make_fx(f)(*example_xs, *example_pos_args)
 
+    from torch._functorch.aot_autograd import AOTConfig, create_joint
+
+    dummy_aot_config = AOTConfig(
+        fw_compiler=None,  # type: ignore[arg-type]
+        bw_compiler=None,  # type: ignore[arg-type]
+        partition_fn=None,  # type: ignore[arg-type]
+        decompositions={},
+        num_params_buffers=0,
+        aot_id=0,
+        keep_inference_input_mutations=False,
+    )
+
     def joint_f(*example_args):
         joint_mapped_args = example_args[:joint_num_mapped]
         args = example_args[joint_num_mapped:]
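The effect of the change in the diff above is that the dependency is imported on first call rather than at module import time. A minimal sketch of that behavior, using the stdlib module colorsys as a stand-in for torch._functorch.aot_autograd (which is the assumption here, since importing torch itself is heavyweight):

```python
import sys

def make_config():
    # Deferred import, mirroring the patch: the dependency (and everything
    # it transitively imports) is only loaded the first time this function
    # runs, not when the enclosing module is imported.
    import colorsys  # stand-in for torch._functorch.aot_autograd
    return colorsys.rgb_to_hsv(1.0, 0.0, 0.0)

# Defining the function above did not import the dependency yet.
assert "colorsys" not in sys.modules

make_config()

# The first call triggers the import; later calls reuse sys.modules.
assert "colorsys" in sys.modules
```

This is exactly why the mocked test saw different behavior: any import-time side effects of the deferred module now fire during test execution rather than during test construction.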
