Skip to content

Commit 30293b8

Browse files
narekmalkpytorchmergebot
authored andcommitted
Preserve Enum types during torch.export serialization and deserialization (pytorch#154821)
Fixes pytorch#154674 Addresses an issue where `torch.export` does not correctly preserve Python `Enum` types during the save/load round-trip. Previously, Enum inputs were serialized by value only, causing their type to be lost after deserialization. Pull Request resolved: pytorch#154821 Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007, https://github.com/yushangdi, https://github.com/angelayi
1 parent 27df0c5 commit 30293b8

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

test/test_pytree.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,21 @@ class DirectNamedTuple2(NamedTuple):
859859
self.assertFalse(pytree.is_namedtuple(cls))
860860
self.assertFalse(pytree.is_namedtuple_class(cls))
861861

862+
@parametrize(
863+
"pytree",
864+
[
865+
subtest(py_pytree, name="py"),
866+
subtest(cxx_pytree, name="cxx"),
867+
],
868+
)
869+
def test_enum_treespec_roundtrip(self, pytree):
870+
data = {TestEnum.A: 5}
871+
spec = pytree.tree_structure(data)
872+
873+
serialized = pytree.treespec_dumps(spec)
874+
deserialized_spec = pytree.treespec_loads(serialized)
875+
self.assertEqual(spec, deserialized_spec)
876+
862877

863878
class TestPythonPytree(TestCase):
864879
def test_deprecated_register_pytree_node(self):

torch/utils/_pytree.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,14 @@ def get(self, parent: Any) -> Any:
113113

114114

115115
class EnumEncoder(json.JSONEncoder):
116-
def default(self, obj: object) -> str:
116+
def default(self, obj: object) -> Union[str, dict[str, Any]]:
117117
if isinstance(obj, Enum):
118-
return obj.value # type: ignore[no-any-return]
119-
return super().default(obj) # type: ignore[no-any-return]
118+
return {
119+
"__enum__": True,
120+
"fqn": f"{obj.__class__.__module__}:{obj.__class__.__qualname__}",
121+
"name": obj.name,
122+
}
123+
return cast(str, super().default(obj))
120124

121125

122126
Context = Any
@@ -1836,6 +1840,18 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
18361840
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
18371841

18381842

1843+
def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
1844+
if "__enum__" in obj:
1845+
modname, _, classname = obj["fqn"].partition(":")
1846+
mod = importlib.import_module(modname)
1847+
enum_cls = mod
1848+
for attr in classname.split("."):
1849+
enum_cls = getattr(enum_cls, attr)
1850+
enum_cls = cast(type[Enum], enum_cls)
1851+
return enum_cls[obj["name"]]
1852+
return obj
1853+
1854+
18391855
def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
18401856
if (
18411857
json_schema["type"] is None
@@ -1854,7 +1870,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
18541870

18551871
if serialize_node_def.from_dumpable_context is None:
18561872
try:
1857-
context = json.loads(json_schema["context"])
1873+
context = json.loads(json_schema["context"], object_hook=enum_object_hook)
18581874
except TypeError as ex:
18591875
raise TypeError(
18601876
"Unable to deserialize context. "

0 commit comments

Comments
 (0)