Skip to content

Commit 7e80f23

Browse files
angelayi authored and pytorchmergebot committed
[export] Move PT2ArchiveWriter/Reader to torch/export (pytorch#153795)
Summary:

Before: `from sigmoid.core.package.pt2_archive import PT2ArchiveWriter, PT2ArchiveReader, is_sigmoid_package`

After: `from torch.export.pt2_archive import PT2ArchiveWriter, PT2ArchiveReader, is_pt2_package`

Because the two PT2ArchiveReader/Writer implementations were merged into a single one built on the native PyTorchFileReader/Writer, the open source PT2 archive now contains an additional top-level folder. This PR still supports loading an old PT2 archive that does not have the additional folder.

Before:
```
├── archive_format
├── byteorder
├── .data
│   ├── serialization_id
│   └── version
├── data
│   ├── aotinductor
```

After:
```
├── tmp
│   ├── archive_format
│   ├── byteorder
│   ├── .data
│   │   ├── serialization_id
│   │   └── version
│   ├── data
│   │   ├── aotinductor
```

Test Plan: `buck2 test //sigmoid/...` https://www.internalfb.com/intern/testinfra/testrun/5348024839248187

Differential Revision: D74616598

Pull Request resolved: pytorch#153795

Approved by: https://github.com/zhxchen17
1 parent 214e4ce commit 7e80f23

File tree

6 files changed

+231
-85
lines changed

6 files changed

+231
-85
lines changed

test/allowlist_for_publicAPI.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2650,6 +2650,11 @@
26502650
"torch.export.graph_signature": [
26512651
"TokenArgument"
26522652
],
2653+
"torch.export.pt2_archive": [
2654+
"PT2ArchiveWriter",
2655+
"PT2ArchiveReader",
2656+
"is_pt2_package"
2657+
],
26532658
"torch.fx.experimental.shape_inference.infer_shape": [
26542659
"DimDynamic",
26552660
"FakeTensorMode",

test/inductor/test_aot_inductor_package.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,10 @@ def forward(self, x, y):
216216
with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
217217
package_path, "r"
218218
) as zip_ref:
219+
filenames = zip_ref.namelist()
220+
prefix = filenames[0].split("/")[0]
219221
zip_ref.extractall(tmp_dir)
220-
tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
222+
tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
221223
self.assertTrue(tmp_path.exists())
222224
if self.device == GPU_TYPE:
223225
kernel_bin = get_kernel_bin_format(self.device)

torch/_inductor/package/package.py

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,17 @@
33
import logging
44
import os
55
import tempfile
6-
import zipfile
7-
from pathlib import Path
86
from typing import Any, IO, Optional, Union
9-
from typing_extensions import Self
107

118
import torch
129
import torch._inductor
1310
import torch.utils._pytree as pytree
1411
from torch._inductor import config
1512
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
1613
from torch.export._tree_utils import reorder_kwargs
14+
from torch.export.pt2_archive._package import PT2ArchiveWriter
1715
from torch.export.pt2_archive.constants import (
1816
AOTINDUCTOR_DIR,
19-
ARCHIVE_VERSION_VALUE,
2017
CONSTANTS_DIR,
2118
CUSTOM_OBJ_FILENAME_PREFIX,
2219
)
@@ -26,74 +23,6 @@
2623
log = logging.getLogger(__name__)
2724

2825

29-
class PT2ArchiveWriter:
30-
def __init__(self, archive_path: FileLike) -> None:
31-
self.archive_path: FileLike = archive_path
32-
self.archive_file: Optional[zipfile.ZipFile] = None
33-
34-
def __enter__(self) -> Self:
35-
assert self.archive_file is None
36-
self.archive_file = zipfile.ZipFile(
37-
self.archive_path, "w", compression=zipfile.ZIP_STORED
38-
)
39-
self.writestr("version", str(ARCHIVE_VERSION_VALUE))
40-
self.writestr("archive_format", "pt2")
41-
return self
42-
43-
def __exit__(self, *args) -> None: # type: ignore[no-untyped-def]
44-
assert self.archive_file is not None
45-
self.archive_file.close()
46-
self.archive_file = None
47-
return None
48-
49-
def writestr(self, name: str, data: Union[bytes, str]) -> None:
50-
assert self.archive_file is not None
51-
self.archive_file.writestr(name, data)
52-
53-
def write_file(self, name: str, file_path: str) -> None:
54-
"""
55-
Copy a file into the archive.
56-
name: The destination file inside the archive.
57-
file_path: The source file on disk.
58-
"""
59-
assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
60-
assert self.archive_file is not None
61-
self.archive_file.write(file_path, arcname=name)
62-
63-
64-
class PT2ArchiveReader:
65-
def __init__(self, archive_path: str) -> None:
66-
self.archive_path: str = archive_path
67-
self.archive_file: Optional[zipfile.ZipFile] = None
68-
69-
def __enter__(self) -> Self:
70-
self.archive_file = zipfile.ZipFile(
71-
self.archive_path, "r", compression=zipfile.ZIP_STORED
72-
)
73-
return self
74-
75-
def __exit__(self, *args) -> None: # type: ignore[no-untyped-def]
76-
if self.archive_file is not None:
77-
self.archive_file.close()
78-
return None
79-
80-
def read(self, name: str) -> bytes:
81-
assert self.archive_file is not None
82-
return self.archive_file.read(name)
83-
84-
def extract_to_path(self, member: str, path: str) -> str:
85-
assert self.archive_file is not None
86-
return self.archive_file.extract(member, path)
87-
88-
def extractall(self, path: str) -> None:
89-
assert self.archive_file is not None
90-
self.archive_file.extractall(path)
91-
92-
def get_file_names(self) -> list[str]:
93-
assert self.archive_file is not None
94-
return self.archive_file.namelist()
95-
96-
9726
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
9827
def get_aoti_file_with_suffix(suffix: str) -> str:
9928
for file in aoti_files:

torch/csrc/inductor/aoti_package/model_package_loader.cpp

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -367,15 +367,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
367367
mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
368368
}
369369

370-
temp_dir_ = create_temp_dir();
371-
std::string so_filename;
372-
std::string cpp_filename;
373-
std::vector<std::string> obj_filenames;
374-
std::string found_filenames; // Saving for bookkeeping
375-
std::string model_directory =
376-
"data" + k_separator + "aotinductor" + k_separator + model_name;
377-
std::string const_directory = "data" + k_separator + "constants";
378-
370+
std::vector<std::string> found_filenames;
379371
for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
380372
uint32_t filename_len =
381373
mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
@@ -389,10 +381,40 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
389381
&zip_archive, i, filename_str.data(), filename_len)) {
390382
throw std::runtime_error("Failed to read filename");
391383
}
384+
found_filenames.push_back(filename_str);
385+
}
386+
387+
if (found_filenames.empty()) {
388+
throw std::runtime_error("No files found in zip archive.");
389+
}
390+
391+
// All the paths are prepended with a tmp/ directory. We need to find the
392+
// prefix.
393+
std::string file_prefix;
394+
size_t pos = found_filenames[0].find('/');
395+
std::string prefix0 = found_filenames[0].substr(0, pos);
396+
pos = found_filenames[1].find('/');
397+
std::string prefix1 = found_filenames[1].substr(0, pos);
392398

393-
found_filenames += filename_str;
394-
found_filenames += " ";
399+
if (!prefix0.empty() && !prefix1.empty() && prefix0 == prefix1) {
400+
file_prefix = prefix0 + "/";
401+
} else {
402+
LOG(WARNING)
403+
<< "You are using an outdated version of the pt2 archive which do not have a prefix in front of each filename. Example: \n"
404+
<< found_filenames[0] << "\n"
405+
<< found_filenames[1];
406+
}
407+
408+
temp_dir_ = create_temp_dir();
409+
410+
std::string so_filename;
411+
std::string cpp_filename;
412+
std::vector<std::string> obj_filenames;
413+
std::string model_directory = file_prefix + "data" + k_separator +
414+
"aotinductor" + k_separator + model_name;
415+
std::string const_directory = "data" + k_separator + "constants";
395416

417+
for (const std::string& filename_str : found_filenames) {
396418
// Only compile files in the specified model directory
397419
if (c10::starts_with(filename_str, model_directory) ||
398420
c10::starts_with(filename_str, const_directory)) {
@@ -460,9 +482,13 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
460482
}
461483

462484
if (cpp_filename.empty() && so_filename.empty()) {
485+
std::string found_filenames_str;
486+
for (const std::string& filename : found_filenames) {
487+
found_filenames_str += filename + "\n";
488+
}
463489
throw std::runtime_error(
464490
"No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" +
465-
found_filenames);
491+
found_filenames_str);
466492
}
467493

468494
// Compile the .so
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter
2+
3+
4+
__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"]
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# pyre-unsafe
2+
3+
import glob
4+
import io
5+
import logging
6+
import os
7+
import zipfile
8+
from typing import Any, Union
9+
10+
import torch
11+
from torch.export.pt2_archive.constants import (
12+
ARCHIVE_FORMAT_PATH,
13+
ARCHIVE_FORMAT_VALUE,
14+
ARCHIVE_VERSION_PATH,
15+
ARCHIVE_VERSION_VALUE,
16+
)
17+
from torch.types import FileLike
18+
19+
20+
logger: logging.Logger = logging.getLogger(__name__)
21+
22+
23+
def is_pt2_package(serialized_model: Union[bytes, str]) -> bool:
    """
    Check if the serialized model is a PT2 Archive package.

    serialized_model: Either the raw bytes of the archive or a path to the
        archive on disk.

    Returns True only when the archive's root folder contains the archive
    format record and that record reads ``b"pt2"``. Input that cannot be
    opened as a zip file, or that contains no entries, is reported as not
    being a PT2 package.
    """
    source = (
        io.BytesIO(serialized_model)
        if isinstance(serialized_model, bytes)
        else serialized_model
    )
    try:
        # Use the context manager so the underlying file handle is released
        # even when reading the archive fails part-way through.
        with zipfile.ZipFile(source) as zip_reader:
            names = zip_reader.namelist()
            if not names:
                return False

            # Zip member names always use "/" as the separator, independent
            # of the host OS, so do not split on os.path.sep here.
            root_folder = names[0].split("/")[0]
            archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
            if archive_format_path in names:
                return zip_reader.read(archive_format_path) == b"pt2"
    except Exception as ex:
        # Best-effort probe: any failure simply means "not a PT2 package".
        logger.info("Model is not a PT2 package: %s", str(ex))
    return False
40+
41+
42+
class PT2ArchiveWriter:
43+
"""
44+
Context manager for writing a PT2 archive.
45+
"""
46+
47+
def __init__(self, archive_path_or_buffer: FileLike):
48+
self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type]
49+
# NOTICE: version here is different from the archive_version
50+
# this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version
51+
# archive_version is the version of the PT2 archive spec, which write to /archive_version
52+
self.archive_file.set_min_version(6)
53+
54+
def __enter__(self) -> "PT2ArchiveWriter":
55+
return self
56+
57+
def __exit__(self, *args: Any) -> None:
58+
if not self.has_record(ARCHIVE_FORMAT_PATH):
59+
self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)
60+
61+
if not self.has_record(ARCHIVE_VERSION_PATH):
62+
self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)
63+
64+
self.close()
65+
66+
def has_record(self, name: str) -> bool:
67+
"""
68+
Check if a record exists in the archive.
69+
"""
70+
return name in self.archive_file.get_all_written_records()
71+
72+
def count_prefix(self, prefix: str) -> int:
73+
"""
74+
Count the number of records that start with a given prefix.
75+
"""
76+
return sum(
77+
1
78+
for record in self.archive_file.get_all_written_records()
79+
if record.startswith(prefix)
80+
)
81+
82+
def write_bytes(self, name: str, data: bytes) -> None:
83+
"""
84+
Write a bytes object to the archive.
85+
name: The destination file inside the archive.
86+
data: The bytes object to write.
87+
"""
88+
assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
89+
self.archive_file.write_record(name, data, len(data))
90+
91+
def write_string(self, name: str, data: str) -> None:
92+
"""
93+
Write a string object to the archive.
94+
name: The destination file inside the archive.
95+
data: The string object to write.
96+
"""
97+
assert isinstance(data, str), f"Expected string but got {type(data)}"
98+
data_bytes = data.encode()
99+
self.write_bytes(name, data_bytes)
100+
101+
def write_file(self, name: str, file_path: str) -> None:
102+
"""
103+
Copy a file into the archive.
104+
name: The destination file inside the archive.
105+
file_path: The source file on disk.
106+
"""
107+
assert os.path.isfile(file_path), f"{file_path} is not a valid file path"
108+
109+
with open(file_path, "rb") as f:
110+
file_bytes = f.read()
111+
self.write_bytes(name, file_bytes)
112+
113+
def write_folder(self, archive_dir: str, folder_dir: str) -> None:
114+
"""
115+
Copy a folder into the archive.
116+
archive_dir: The destination folder inside the archive.
117+
folder_dir: The source folder on disk.
118+
"""
119+
assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"
120+
121+
file_paths = filter(
122+
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
123+
)
124+
for file_path in file_paths:
125+
filename = os.path.relpath(file_path, folder_dir)
126+
archive_path = os.path.join(archive_dir, filename)
127+
self.write_file(archive_path, file_path)
128+
129+
def close(self) -> None:
130+
"""
131+
Close the archive.
132+
"""
133+
self.archive_file.write_end_of_file()
134+
135+
136+
class PT2ArchiveReader:
    """
    Context manager for reading a PT2 archive.

    The archive is opened through ``torch._C.PyTorchFileReader`` and its
    archive format record is validated as soon as the reader is constructed.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        """
        archive_path_or_buffer: The source handed to PyTorchFileReader.
        """
        self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)  # type: ignore[arg-type]
        # Reject anything whose format record is not the expected value.
        found_format = self.read_string(ARCHIVE_FORMAT_PATH)
        assert found_format == ARCHIVE_FORMAT_VALUE, "Invalid archive format"

    def __enter__(self) -> "PT2ArchiveReader":
        return self

    def __exit__(self, *args: Any) -> None:
        # Nothing to release: torch._C.PyTorchFileReader doesn't have a close method
        return None

    def read_bytes(self, name: str) -> bytes:
        """
        Read a bytes object from the archive.
        name: The source file inside the archive.
        """
        record = self.archive_file.get_record(name)
        return record

    def read_string(self, name: str) -> str:
        """
        Read a string object from the archive.
        name: The source file inside the archive.
        """
        return self.read_bytes(name).decode()

    def archive_version(self) -> int:
        """
        Get the archive version.

        Archives without a readable version record are treated as version 0.
        """
        try:
            raw_version = self.read_string(ARCHIVE_VERSION_PATH)
        except Exception:
            # if archive_version is not found, it means the archive is older than version 0.
            # In this case, we assume the archive is version 0.
            return 0

        return int(raw_version)

0 commit comments

Comments
 (0)