
Commit d2bfd97

angelayi authored and pytorchmergebot committed
[export] Refactor pt2 save/load (pytorch#152495)
Refactor the pt2 archive saving to consolidate the format of torch.export.save and torch._inductor.package.package_aoti. This PR adds the following functions, which torch.export.save and AOTI packaging call into:

```python
package_pt2(
    f: FileLike,
    *,
    exported_programs: Optional[Union[ExportedProgram, dict[str, ExportedProgram]]] = None,
    aoti_files: Optional[Union[list[str], dict[str, list[str]]]] = None,
    extra_files: Optional[dict[str, Any]] = None,
) -> FileLike

@dataclass
class PT2ArchiveContents:
    exported_programs: dict[str, ExportedProgram]
    aoti_runners: dict[str, AOTICompiledModel]
    extra_files: dict[str, Any]

load_pt2(f: FileLike) -> PT2ArchiveContents
```

Power users can call directly into these APIs if they want to bundle multiple exported programs, AOTI files, or extra metadata.

This is what the pt2 archive looks like ([spec](https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0)):

```
├── archive_format
├── version
├── .data
├── data
│   ├── aotinductor
│   │   └── model1
│   │       ├── model1.cpp
│   │       ├── model1.so  # currently AOTI automatically moves weights in here, TODO to move it out
│   │       ├── cg7domx3woam3nnliwud7yvtcencqctxkvvcafuriladwxw4nfiv.cubin
│   │       └── cubaaxppb6xmuqdm4bej55h2pftbce3bjyyvljxbtdfuolmv45ex.cubin
│   ├── weights
│   │   ├── model1.pt  # TODO to dedup weights between model1/model2
│   │   └── model2.pt
│   ├── constants
│   │   ├── model1.pt  # TODO to dedup weights between model1/model2
│   │   └── model2.pt
│   └── sample_inputs
│       ├── model1.pt  # TODO to dedup weights between model1/model2
│       └── model2.pt
├── extra
│   └── user_metadata.txt
└── models
    ├── model1.json
    └── model2.json
```

Future TODOs:
- Unbundle the weights: instead of .pt files we can use bin files, which will also let us dedup weights when multiple models are stored.
- Update aoti_compile_and_package to also save the exported program.
- Integrate TNR with this packaging flow.

Pull Request resolved: pytorch#152495
Approved by: https://github.com/yushangdi
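For illustration, a minimal sketch of the power-user flow described above (the module definitions and the `bundle.pt2` path are hypothetical; the keyword names follow the signatures added in this PR):

```python
import torch
from torch.export.pt2_archive._package import load_pt2, package_pt2


class A(torch.nn.Module):
    def forward(self, x):
        return x + 1


class B(torch.nn.Module):
    def forward(self, x):
        return x - 1


inp = (torch.randn(2, 2),)

# Bundle two exported programs plus user metadata into a single pt2 archive.
package_pt2(
    "bundle.pt2",
    exported_programs={
        "model1": torch.export.export(A(), inp),
        "model2": torch.export.export(B(), inp),
    },
    extra_files={"user_metadata.txt": "hello"},
)

# Load everything back as a PT2ArchiveContents dataclass.
contents = load_pt2("bundle.pt2")
ep1 = contents.exported_programs["model1"]
assert contents.extra_files["user_metadata.txt"] == "hello"
```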
1 parent 75b24c2 commit d2bfd97

File tree

5 files changed: +472 -167 lines changed

test/export/test_serialize.py

Lines changed: 17 additions & 9 deletions
```diff
@@ -34,9 +34,12 @@
 from torch._higher_order_ops.torchbind import enable_torchbind_tracing
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 from torch.export import Dim, export_for_training, load, save, unflatten
+from torch.export.pt2_archive.constants import ARCHIVE_VERSION_PATH
 from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
+    IS_FBCODE,
+    IS_MACOS,
     IS_WINDOWS,
     parametrize,
     run_tests,
@@ -1491,6 +1494,7 @@ def forward(self, x):

         self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

+    @unittest.skipIf(IS_WINDOWS, "Cannot modify file in windows")
     def test_save_file(self):
         class Foo(torch.nn.Module):
             def forward(self, x):
@@ -1501,10 +1505,10 @@ def forward(self, x):
         inp = (torch.randn(2, 2),)
         ep = export_for_training(f, inp, strict=True)

-        with tempfile.NamedTemporaryFile() as f:
-            save(ep, f)
+        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
+            save(ep, f.name)
             f.seek(0)
-            loaded_ep = load(f)
+            loaded_ep = load(f.name)

         self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

@@ -1518,7 +1522,7 @@ def forward(self, x, y):
         inp = (torch.tensor([6]), torch.tensor([7]))
         ep = export_for_training(f, inp, strict=True)

-        with TemporaryFileName() as fname:
+        with TemporaryFileName(suffix=".pt2") as fname:
             path = Path(fname)
             save(ep, path)
             loaded_ep = load(path)
@@ -1545,6 +1549,9 @@ def forward(self, x):
         self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))
         self.assertEqual(extra_files["extra.txt"], "moo")

+    @unittest.skipIf(
+        IS_FBCODE or IS_MACOS or IS_WINDOWS, "The file path is different in fbcode CI"
+    )
     def test_version_error(self):
         class Foo(torch.nn.Module):
             def forward(self, x):
@@ -1555,18 +1562,19 @@ def forward(self, x):
         ep = export_for_training(f, (torch.randn(1, 3),), strict=True)

         with self.assertRaisesRegex(
-            RuntimeError, r"Serialized version .* does not match our current"
+            ValueError, r"Saved archive version -1 does not match our current"
         ):
-            with tempfile.NamedTemporaryFile() as f:
-                save(ep, f)
+            with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
+                save(ep, f.name)
                 f.seek(0)
+                file_prefix = f.name.split("/")[2].split(".")[0]

                 # Modify the version
                 with zipfile.ZipFile(f, "a") as zipf:
-                    zipf.writestr("version", "-1.1")
+                    zipf.writestr(f"{file_prefix}/{ARCHIVE_VERSION_PATH}", "-1")

                 f.seek(0)
-                load(f)
+                load(f.name)

     def test_save_constants(self):
         class Foo(torch.nn.Module):
```

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
```diff
@@ -1561,6 +1561,7 @@ class PyTorchFileReader:
     @overload
     def __init__(self, buffer: IO[bytes]) -> None: ...
     def get_record(self, name: str) -> bytes: ...
+    def get_all_records(self) -> list[str]: ...
     def serialization_id(self) -> str: ...

 class PyTorchFileWriter:
```
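A quick sketch of the new `get_all_records` binding, handy for inspecting what a pt2 archive contains (the `bundle.pt2` path is hypothetical):

```python
import torch

# PyTorchFileReader wraps the zip-based container format used by pt2 archives.
reader = torch._C.PyTorchFileReader("bundle.pt2")
for record_name in reader.get_all_records():  # every record path in the archive
    print(record_name)
```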

torch/_inductor/package/package.py

Lines changed: 16 additions & 131 deletions
```diff
@@ -3,20 +3,12 @@
 import logging
 import os
 import tempfile
-from typing import Any, IO, Optional, Union
+from typing import IO, Union

 import torch
-import torch._inductor
-import torch.utils._pytree as pytree
 from torch._inductor import config
 from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
-from torch.export._tree_utils import reorder_kwargs
-from torch.export.pt2_archive._package import PT2ArchiveWriter
-from torch.export.pt2_archive.constants import (
-    AOTINDUCTOR_DIR,
-    CONSTANTS_DIR,
-    CUSTOM_OBJ_FILENAME_PREFIX,
-)
+from torch.export.pt2_archive._package import AOTICompiledModel, load_pt2, package_pt2
 from torch.types import FileLike


@@ -95,122 +87,8 @@ def package_aoti(
         the AOTInductor files, or a dictionary mapping the model name to the
         path to its AOTInductor generated files.
     """
-    if isinstance(aoti_files, list):
-        aoti_files = {"model": aoti_files}
-
-    assert isinstance(aoti_files, dict), (
-        "Please pass a list of AOTI generated files to be packaged or "
-        "a dictionary mapping model names to their list of AOTI generated "
-        "files. You can get this list of files through calling "
-        "`torch._inductor.aot_compile(..., options={aot_inductor.package=True})`"
-    )
-    assert (
-        isinstance(archive_file, (io.IOBase, IO))
-        and archive_file.writable()
-        and archive_file.seekable()
-    ) or (
-        isinstance(archive_file, (str, os.PathLike))
-        and os.fspath(archive_file).endswith(".pt2")
-    ), (
-        f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
-    )

-    # Save using the PT2 packaging format
-    # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
-
-    with PT2ArchiveWriter(archive_file) as archive_writer:
-        for model_name, files in aoti_files.items():
-            num_so_files = 0
-            num_cpp_files = 0
-
-            for file in files:
-                if file == "":
-                    continue
-
-                if file.endswith(".so"):
-                    num_so_files += 1
-                    if num_so_files > 1:
-                        raise RuntimeError(
-                            f"Multiple .so files found in {files}. "
-                            "You might need to clear your cache "
-                            "directory before calling aoti_compile again."
-                        )
-                if file.endswith(".cpp"):
-                    num_cpp_files += 1
-                    if num_so_files > 1:
-                        raise RuntimeError(
-                            f"Multiple .cpp files found in {files}. "
-                            "You might need to clear your cache "
-                            "directory before calling aoti_compile again."
-                        )
-
-                filename = os.path.basename(file)
-                if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX):
-                    new_filepath = os.path.join(CONSTANTS_DIR, filename)
-                else:
-                    new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename)
-                log.debug(
-                    "Saving AOTI generated file %s to archive in %s", file, new_filepath
-                )
-                archive_writer.write_file(
-                    str(new_filepath),
-                    file,
-                )
-
-    if isinstance(archive_file, (io.IOBase, IO)):
-        archive_file.seek(0)
-    return archive_file
-
-
-class AOTICompiledModel:
-    """
-    Callable AOT Inductor loaded model from a .pt2
-    """
-
-    def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None:
-        self.loader = loader
-
-    def __call__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
-        call_spec = self.loader.get_call_spec()  # type: ignore[attr-defined]
-        in_spec = pytree.treespec_loads(call_spec[0])
-        out_spec = pytree.treespec_loads(call_spec[1])
-        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
-        flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
-        flat_outputs = self.loader.boxed_run(flat_inputs)  # type: ignore[attr-defined]
-        return pytree.tree_unflatten(flat_outputs, out_spec)
-
-    def get_metadata(self) -> dict[str, str]:
-        return self.loader.get_metadata()  # type: ignore[attr-defined]
-
-    def load_constants(
-        self,
-        constants_map: dict[str, torch.Tensor],
-        *,
-        check_full_update: bool,
-        user_managed: bool = False,
-    ) -> None:
-        """
-        Given a mapping of constant fqns to tensors, load the constants into the model.
-        You can use ``get_constant_fqns`` to get the list of constant fqns that
-        are needed in the compiled model.
-
-        Args:
-            constants_map: A mapping of constant fqns to tensors.
-            check_full_update: Whether to add check to see if all the constants
-                are updated and have values.
-        """
-        self.loader.load_constants(  # type: ignore[attr-defined]
-            constants_map, False, check_full_update, user_managed
-        )
-
-    def get_constant_fqns(self) -> list[str]:
-        return self.loader.get_constant_fqns()  # type: ignore[attr-defined]
-
-    def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel":
-        log.warning(
-            "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied."
-        )
-        return AOTICompiledModel(self.loader)  # type: ignore[attr-defined]
+    return package_pt2(archive_file, aoti_files=aoti_files)


 def load_package(
@@ -220,18 +98,25 @@ def load_package(
     num_runners: int = 1,
     device_index: int = -1,
 ) -> AOTICompiledModel:  # type: ignore[type-arg]
-    assert (
-        isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
-    ) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), (
-        f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
-    )
+    try:
+        pt2_contents = load_pt2(
+            path,
+            run_single_threaded=run_single_threaded,
+            num_runners=num_runners,
+            device_index=device_index,
+        )
+        if model_name not in pt2_contents.aoti_runners:
+            raise RuntimeError(f"Model {model_name} not found in package")
+        return pt2_contents.aoti_runners[model_name]
+    except RuntimeError:
+        log.warning("Loading outdated pt2 file. Please regenerate your package.")

     if isinstance(path, (io.IOBase, IO)):
         with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
             # TODO(angelayi): We shouldn't need to do this -- miniz should
             # handle reading the buffer. This is just a temporary workaround
-            f.write(path.read())
             path.seek(0)
+            f.write(path.read())
             log.debug("Writing buffer to tmp file located at %s.", f.name)
             loader = torch._C._aoti.AOTIModelPackageLoader(
                 f.name, model_name, run_single_threaded, num_runners, device_index
```
torch/export/__init__.py

Lines changed: 32 additions & 24 deletions
```diff
@@ -381,29 +381,15 @@ def forward(self, x):
             f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead."
         )

-    from torch._export.serde.schema import SCHEMA_VERSION
-    from torch._export.serde.serialize import serialize, SerializedArtifact
-
-    artifact: SerializedArtifact = serialize(ep, opset_version, pickle_protocol)
-
-    if isinstance(f, (str, os.PathLike)):
-        f = os.fspath(f)
-
-    with zipfile.ZipFile(f, "w") as zipf:
-        # Save every field in the SerializedArtifact to a file.
-        assert isinstance(artifact.exported_program, bytes)
-        zipf.writestr("serialized_exported_program.json", artifact.exported_program)
-        zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
-        zipf.writestr("serialized_constants.pt", artifact.constants)
-        zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs)
-
-        zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))
-
-        # Add extra files if provided
-        if extra_files:
-            for extra_file_name, content in extra_files.items():
-                encoded_content = content.encode("utf-8")
-                zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
+    from torch.export.pt2_archive._package import package_pt2
+
+    package_pt2(
+        f,
+        exported_programs={"model": ep},
+        extra_files=extra_files,
+        pickle_protocol=pickle_protocol,
+        opset_version=opset_version,
+    )


 def load(
@@ -460,10 +446,32 @@ def load(

     extra_files = extra_files or {}

+    from torch.export.pt2_archive._package import load_pt2, PT2ArchiveContents
+
+    try:
+        pt2_contents = load_pt2(
+            f,
+            expected_opset_version=expected_opset_version,
+        )
+    except RuntimeError:
+        pt2_contents = PT2ArchiveContents({}, {}, {})
+
+    if len(pt2_contents.exported_programs) > 0 or len(pt2_contents.extra_files) > 0:
+        for k, v in pt2_contents.extra_files.items():
+            extra_files[k] = v
+
+        return pt2_contents.exported_programs["model"]
+
+    # TODO: For backward compatibility, we support loading a zip file from 2.7. Delete this path in 2.9(?)
+    warnings.warn(
+        "This version of file is deprecated. Please generate a new pt2 saved file."
+    )
     with zipfile.ZipFile(f, "r") as zipf:
         # Check the version
         version = zipf.read("version").decode().split(".")
-        from torch._export.serde.schema import SCHEMA_VERSION
+        from torch._export.serde.schema import (
+            SCHEMA_VERSION,  # todo change archive version to schema version
+        )

         assert len(version) == len(SCHEMA_VERSION)
         if version[0] != str(SCHEMA_VERSION[0]):
```
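On the `torch.export` side the public API is unchanged; `save` and `load` simply route through the pt2 archive helpers, including `extra_files` handling. A small sketch (the module, path, and file contents are illustrative):

```python
import torch
from torch.export import export, load, save


class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()


ep = export(M(), (torch.randn(3),))

# extra_files entries end up under extra/ in the archive (see the layout above).
save(ep, "model.pt2", extra_files={"user_metadata.txt": "trained on v2 data"})

received: dict[str, str] = {}
loaded_ep = load("model.pt2", extra_files=received)
assert received["user_metadata.txt"] == "trained on v2 data"
```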
