Skip to content

Commit 06e6207

Browse files
alanhdufacebook-github-bot
authored andcommitted
Update torchtnt for Python 3.12 (#964)
Summary: Pull Request resolved: #964 In Python 3.12, dataclass field defaults must be immutable, and implementing a custom `__eq__` on the enum will cause an error like: ``` ValueError: mutable default <enum 'StoppingMechanism'> for field stopping_mechanism is not allowed: use default_factory ``` To work around this, we can inherit from `StrEnum` (newly added into Python 3.12) which supports direct `==` comparisons against strings. For backwards compatibility, we define our own version of it as well. Reviewed By: diego-urgell Differential Revision: D68466161 fbshipit-source-id: 1c831cb4449adc953d921ee3559616d5c7b00f86
1 parent 3232a91 commit 06e6207

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

torchtnt/utils/data/iterators.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@
1010
from __future__ import annotations
1111

1212
import logging
13-
1413
import random
1514
from abc import abstractmethod
1615
from dataclasses import dataclass
17-
from enum import Enum
1816
from itertools import cycle
1917
from typing import (
2018
Any,
@@ -31,6 +29,16 @@
3129
Union,
3230
)
3331

32+
try:
33+
# pyre-ignore[21]: Could not find name `StrEnum` in `enum`
34+
from enum import StrEnum
35+
except ImportError:
36+
from enum import Enum
37+
38+
class StrEnum(str, Enum):
39+
pass
40+
41+
3442
import torch
3543
import torch.distributed as dist
3644

@@ -86,23 +94,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
8694
pass
8795

8896

89-
class StoppingMechanism(Enum):
97+
class StoppingMechanism(StrEnum):
9098
ALL_DATASETS_EXHAUSTED = "ALL_DATASETS_EXHAUSTED"
9199
SMALLEST_DATASET_EXHAUSTED = "SMALLEST_DATASET_EXHAUSTED"
92100
RESTART_UNTIL_ALL_DATASETS_EXHAUSTED = "RESTART_UNTIL_ALL_DATASETS_EXHAUSTED"
93101

94102
# used with RandomizedBatchSampler
95103
WRAP_AROUND_UNTIL_KILLED = "WRAP_AROUND_UNTIL_KILLED"
96104

97-
def __eq__(self, other: Union[str, StoppingMechanism]) -> bool:
98-
"""
99-
Enable comparison betwen string and instances of StoppingMechanism
100-
"""
101-
if isinstance(other, str):
102-
return self.value == other
103-
104-
return super().__eq__(other)
105-
106105

107106
@dataclass
108107
class RoundRobin(DataIterationStrategy):

0 commit comments

Comments
 (0)