Skip to content

Commit 7a40353

Browse files
ethanbwaitefacebook-github-bot
authored andcommitted
Add from_str methods to enums (#1095)
Summary: Adds from_str methods to help parse enums from string literals Reviewed By: hokyungh-m Differential Revision: D79773872
1 parent 0255f71 commit 7a40353

File tree

3 files changed

+66
-1
lines changed

3 files changed

+66
-1
lines changed

dev-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ torch>=2.7.0
2929
torchmetrics==1.6.3
3030
torchserve>=0.10.0
3131
torchtext==0.18.0
32-
torchvision==0.22.0
32+
torchvision==0.23.0
3333
typing-extensions
3434
ts==0.5.1
3535
ray[default]

torchx/specs/api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,32 @@ class RetryPolicy(str, Enum):
264264
APPLICATION = "APPLICATION"
265265
ROLE = "ROLE"
266266

267+
@staticmethod
268+
def from_str(s: str) -> "RetryPolicy":
269+
"""
270+
Returns the retry policy for the given string.
271+
"""
272+
try:
273+
return RetryPolicy[s]
274+
except KeyError as e:
275+
raise ValueError(f"Invalid retry policy: {s}") from e
276+
267277

268278
class MountType(str, Enum):
269279
BIND = "bind"
270280
VOLUME = "volume"
271281
DEVICE = "device"
272282

283+
@staticmethod
284+
def from_str(s: str) -> "MountType":
285+
"""
286+
Returns the mount type for the given string.
287+
"""
288+
try:
289+
return MountType[s]
290+
except KeyError as e:
291+
raise ValueError(f"Invalid mount type: {s}") from e
292+
273293

274294
@dataclass
275295
class BindMount:
@@ -481,6 +501,16 @@ def __str__(self) -> str:
481501
def __repr__(self) -> str:
482502
return f"{self.name} ({self.value})"
483503

504+
@staticmethod
505+
def from_str(s: str) -> "AppState":
506+
"""
507+
Returns the app state for the given string.
508+
"""
509+
try:
510+
return AppState[s]
511+
except KeyError as e:
512+
raise ValueError(f"Invalid app state: {s}") from e
513+
484514

485515
_TERMINAL_STATES: List[AppState] = [
486516
AppState.SUCCEEDED,

torchx/specs/test/api_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
macros,
3232
MalformedAppHandleException,
3333
MISSING,
34+
MountType,
3435
NULL_RESOURCE,
3536
parse_app_handle,
3637
ReplicaStatus,
@@ -55,6 +56,20 @@ def test_repr(self) -> None:
5556

5657

5758
class AppDefStatusTest(unittest.TestCase):
59+
60+
def test_app_state_from_str(self) -> None:
61+
self.assertEqual(AppState.UNSUBMITTED, AppState.from_str("UNSUBMITTED"))
62+
self.assertEqual(AppState.SUBMITTED, AppState.from_str("SUBMITTED"))
63+
self.assertEqual(AppState.PENDING, AppState.from_str("PENDING"))
64+
self.assertEqual(AppState.RUNNING, AppState.from_str("RUNNING"))
65+
self.assertEqual(AppState.SUCCEEDED, AppState.from_str("SUCCEEDED"))
66+
self.assertEqual(AppState.FAILED, AppState.from_str("FAILED"))
67+
self.assertEqual(AppState.CANCELLED, AppState.from_str("CANCELLED"))
68+
self.assertEqual(AppState.UNKNOWN, AppState.from_str("UNKNOWN"))
69+
70+
with self.assertRaises(ValueError):
71+
AppState.from_str("INVALID_STATE")
72+
5873
def test_is_terminal(self) -> None:
5974
for s in AppState:
6075
is_terminal = AppStatus(state=s).is_terminal()
@@ -315,6 +330,26 @@ def test_retry_policies(self) -> None:
315330
},
316331
)
317332

333+
def test_retry_policy_from_str(self) -> None:
334+
# Test valid retry policy strings
335+
self.assertEqual(RetryPolicy.APPLICATION, RetryPolicy.from_str("APPLICATION"))
336+
self.assertEqual(RetryPolicy.REPLICA, RetryPolicy.from_str("REPLICA"))
337+
self.assertEqual(RetryPolicy.ROLE, RetryPolicy.from_str("ROLE"))
338+
339+
# Test invalid retry policy string
340+
with self.assertRaises(ValueError):
341+
RetryPolicy.from_str("INVALID_POLICY")
342+
343+
def test_mount_type_from_str(self) -> None:
344+
# Test valid mount type strings
345+
self.assertEqual(MountType.BIND, MountType.from_str("BIND"))
346+
self.assertEqual(MountType.VOLUME, MountType.from_str("VOLUME"))
347+
self.assertEqual(MountType.DEVICE, MountType.from_str("DEVICE"))
348+
349+
# Test invalid mount type string
350+
with self.assertRaises(ValueError):
351+
MountType.from_str("INVALID_MOUNT_TYPE")
352+
318353
def test_override_role(self) -> None:
319354
default = Role(
320355
"foobar",

0 commit comments

Comments
 (0)