Skip to content

Commit 26f066b

Browse files
yushangdi authored and pytorchmergebot committed
Add AOTI model name config (pytorch#154129)
Summary: If a model name is specified in aoti config, the generated files will use that model name as file stem. Test Plan: ``` buck2 run mode/dev-nosan caffe2/test/inductor:test_aot_inductor -- -r test_using_model_name_for_files ``` Differential Revision: D75102034 Pull Request resolved: pytorch#154129 Approved by: https://github.com/desertfire
1 parent fa705f7 commit 26f066b

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import tempfile
77
import unittest
8+
import zipfile
89
from unittest import skip
910
from unittest.mock import patch
1011

@@ -5971,6 +5972,39 @@ def forward(self, x):
59715972
# the output should have int type
59725973
self.check_model(Model2(), (x,))
59735974

5975+
def test_using_model_name_for_files(self):
5976+
class Model(torch.nn.Module):
5977+
def __init__(self) -> None:
5978+
super().__init__()
5979+
self.linear = torch.nn.Linear(10, 10)
5980+
5981+
def forward(self, x, y):
5982+
return x + self.linear(y)
5983+
5984+
example_inputs = (
5985+
torch.randn(10, 10, device=self.device),
5986+
torch.randn(10, 10, device=self.device),
5987+
)
5988+
model = Model().to(self.device)
5989+
with torch.no_grad():
5990+
package_path: str = AOTIRunnerUtil.compile(
5991+
model,
5992+
example_inputs,
5993+
inductor_configs={
5994+
"aot_inductor.model_name_for_generated_files": "test_model"
5995+
},
5996+
)
5997+
5998+
with zipfile.ZipFile(package_path, "r") as zip_ref:
5999+
all_files = zip_ref.namelist()
6000+
base_dir = "test_model.wrapper/data/aotinductor/model/test_model"
6001+
self.assertTrue(f"{base_dir}.wrapper.cpp" in all_files)
6002+
self.assertTrue(f"{base_dir}.kernel.cpp" in all_files)
6003+
self.assertTrue(f"{base_dir}.wrapper.so" in all_files)
6004+
6005+
aot_inductor_module = torch._inductor.aoti_load_package(package_path)
6006+
self.assertEqual(aot_inductor_module(*example_inputs), model(*example_inputs))
6007+
59746008

59756009
class AOTInductorLoggingTest(LoggingTestCase):
59766010
@make_logging_test(dynamic=logging.DEBUG)

torch/_inductor/codecache.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,6 +1615,10 @@ def get_keys(cls) -> KeysView[str]:
16151615

16161616

16171617
class AotCodeCompiler:
1618+
"""
1619+
Compile AOT Inductor generated code.
1620+
"""
1621+
16181622
@classmethod
16191623
def compile(
16201624
cls,
@@ -1676,6 +1680,7 @@ def compile(
16761680
"wrapper.cpp",
16771681
extra=cpp_command,
16781682
specified_dir=specified_output_path,
1683+
key=config.aot_inductor.model_name_for_generated_files,
16791684
)
16801685
kernel_code = (
16811686
f"// Triton kernels are embedded as comments in {wrapper_path}\n"
@@ -1686,6 +1691,7 @@ def compile(
16861691
"kernel.cpp",
16871692
extra=cpp_command,
16881693
specified_dir=specified_output_path,
1694+
key=config.aot_inductor.model_name_for_generated_files,
16891695
)
16901696

16911697
# Log the AOTInductor wrapper and kernel code, if needed.

torch/_inductor/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,10 @@ class triton:
12571257

12581258

12591259
class aot_inductor:
1260+
"""
1261+
Settings for Ahead-Of-Time Inductor Compilation
1262+
"""
1263+
12601264
# AOTInductor output path
12611265
# If an absolute path is specified, the generated lib files will be stored under the directory;
12621266
# If a relative path is specified, it will be used as a subdirectory under the default caching path;
@@ -1357,6 +1361,10 @@ class aot_inductor:
13571361
# In addition to emit asm files, also emit binary files for current arch
13581362
emit_current_arch_binary: bool = False
13591363

1364+
# If not None, the generated files will use this name as the file stem.
1365+
# If None, we will use a hash to name files.
1366+
model_name_for_generated_files: Optional[str] = None
1367+
13601368
# Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
13611369
custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
13621370
# custom op libs that have implemented C shim wrappers

0 commit comments

Comments
 (0)