Skip to content

Commit 72e8f34

Browse files
yushangdipytorchmergebot
authored andcommitted
[AoTI Minifier] UX Improvement (pytorch#143330)
Summary: - When a user specify `TORCHINDUCTOR_MAX_AUTOTUNE=1` env variable, we add `config.max_autotune=True` to the generated minifier_launcher - We should do this to other inductor configs as well in a followup Diff Currently in dynamo and aoti minifier, if a config is overwritten by an env variable, the config will not show up in the config list in the minifier_launcher.py file. As a result, when running the minifier_launcher, they need to re-apply the same env variable. This is: 1) not convenient for the users 2) if they copy-paste the minifier_launcher.py to us without including the env variable, we could be confused and not able to reproduce the error. Underlying implementation change: - Add `env_default` parameter to `codegen_config()`. If set, configs overriden by the env are not considered default. Test Plan: ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:utils -- -r test_codegen_config ``` Differential Revision: D67299312 Pull Request resolved: pytorch#143330 Approved by: https://github.com/jansel, https://github.com/eellison
1 parent 096cb87 commit 72e8f34

File tree

8 files changed

+78
-9
lines changed

8 files changed

+78
-9
lines changed

test/dynamo/test_debug_utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Owner(s): ["module: dynamo"]
22

3+
import os
34
import unittest
5+
from unittest.mock import patch
46

57
import torch
68
from functorch import make_fx
79
from torch._dynamo import debug_utils
8-
from torch._dynamo.debug_utils import aot_graph_input_parser
10+
from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string
911
from torch._dynamo.test_case import TestCase
1012
from torch.testing._internal.inductor_utils import HAS_CUDA
1113

@@ -172,6 +174,25 @@ def forward(
172174
self.assertEqual(list(kwargs["primals_4"].shape), [5])
173175
self.assertEqual(kwargs["primals_5"], 5)
174176

177+
@patch.dict(os.environ, {"TORCHINDUCTOR_MAX_AUTOTUNE": "1", "TEST_ENV": "1"})
178+
def test_generate_env_vars_string(self):
179+
env_strings = generate_env_vars_string()
180+
self.assertIn(
181+
"""os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1'
182+
""",
183+
env_strings,
184+
)
185+
self.assertIn(
186+
"""import os
187+
""",
188+
env_strings,
189+
)
190+
self.assertNotIn(
191+
"""TEST_ENV
192+
""",
193+
env_strings,
194+
)
195+
175196

176197
if __name__ == "__main__":
177198
from torch._dynamo.test_case import run_tests

test/test_utils_config_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ def test_codegen_config(self):
173173
self.assertEqual(
174174
code,
175175
"""torch.testing._internal.fake_config_module.e_bool = False
176+
torch.testing._internal.fake_config_module.e_env_default = True
177+
torch.testing._internal.fake_config_module.e_env_default_FALSE = False
178+
torch.testing._internal.fake_config_module.e_env_force = True
176179
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
177180
)
178181

torch/_dynamo/config.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
# mypy: allow-untyped-defs
2-
import getpass
32
import inspect
43
import os
54
import re
65
import sys
7-
import tempfile
86
from os.path import abspath, dirname
97
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
108

@@ -437,10 +435,6 @@ def default_debug_dir_root():
437435
DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
438436
if DEBUG_DIR_VAR_NAME in os.environ:
439437
return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
440-
elif is_fbcode():
441-
return os.path.join(
442-
tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug"
443-
)
444438
else:
445439
return os.path.join(os.getcwd(), "torch_compile_debug")
446440

torch/_dynamo/debug_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,31 @@ def _cuda_system_info_comment():
250250
return model_str
251251

252252

253+
def generate_env_vars_string(*, stable_output=False):
254+
"""
255+
Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton.
256+
"""
257+
if stable_output:
258+
return "# env var omitted due to stable_output=True"
259+
260+
allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"]
261+
skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"]
262+
263+
def filter(key):
264+
return any(string in key for string in allow_list) and key not in skip_list
265+
266+
config_lines = [
267+
f"os.environ['{key}'] = '{value}'"
268+
for key, value in os.environ.items()
269+
if filter(key)
270+
]
271+
config_string = "\n".join(config_lines)
272+
return f"""\
273+
import os
274+
{config_string}
275+
"""
276+
277+
253278
def generate_config_string(*, stable_output=False):
254279
import torch._functorch.config
255280
import torch._inductor.config

torch/_dynamo/repro/after_aot.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
extra_deps,
2929
extra_imports,
3030
generate_config_string,
31+
generate_env_vars_string,
3132
helper_for_dump_minify,
3233
InputReader,
3334
InputWriter,
@@ -264,6 +265,7 @@ def generate_compiler_repro_string(
264265
):
265266
model_str = textwrap.dedent(
266267
f"""
268+
{generate_env_vars_string(stable_output=stable_output)}
267269
import torch
268270
from torch import tensor, device
269271
import torch.fx as fx

torch/_dynamo/repro/after_dynamo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
BuckTargetWriter,
2121
extra_imports,
2222
generate_config_string,
23+
generate_env_vars_string,
2324
helper_for_dump_minify,
2425
InputReader,
2526
InputWriter,
@@ -179,6 +180,7 @@ def generate_dynamo_fx_repro_string(
179180

180181
return textwrap.dedent(
181182
f"""
183+
{generate_env_vars_string(stable_output=stable_output)}
182184
from math import inf
183185
import torch
184186
from torch import tensor, device

torch/_dynamo/repro/aoti.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
BuckTargetWriter,
1818
extra_imports,
1919
generate_config_string,
20+
generate_env_vars_string,
2021
helper_for_dump_minify,
2122
InputReader,
2223
minifier_dir,
@@ -193,6 +194,7 @@ def generate_compiler_repro_exported_program(
193194
):
194195
model_str = textwrap.dedent(
195196
f"""
197+
{generate_env_vars_string(stable_output=stable_output)}
196198
import torch
197199
import torch._inductor.inductor_prims
198200
@@ -455,7 +457,7 @@ def common_flags(parser):
455457
)
456458

457459
subparsers = parser.add_subparsers(
458-
dest="command", metavar="{run,minify,analyze}", required=True
460+
dest="command", metavar="{run,minify}", required=True
459461
)
460462

461463
parser_run = subparsers.add_parser(

torch/utils/_config_module.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,27 @@ def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None:
408408
setattr(module, constant_name, val)
409409

410410
def _is_default(self, name: str) -> bool:
411-
return self._config[name].user_override is _UNSET_SENTINEL
411+
"""
412+
Returns true if the config is at its default value.
413+
configs overriden by the env are not considered default.
414+
"""
415+
config_val = self._config[name]
416+
# The config is not overridden by the user, and the env_value_default
417+
# is different from the default value (meaning user has set the env to
418+
# change the default value).
419+
not_set_env_default = (
420+
config_val.env_value_default is _UNSET_SENTINEL
421+
or config_val.env_value_default == config_val.default
422+
)
423+
not_set_env_force = (
424+
config_val.env_value_force is _UNSET_SENTINEL
425+
or config_val.env_value_force == config_val.default
426+
)
427+
return (
428+
config_val.user_override is _UNSET_SENTINEL
429+
and not_set_env_default
430+
and not_set_env_force
431+
)
412432

413433
def _get_dict(
414434
self,

0 commit comments

Comments
 (0)