diff --git a/dev-requirements.txt b/dev-requirements.txt index f6dec1675..0abe53c1d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -29,7 +29,7 @@ torch>=2.7.0 torchmetrics==1.6.3 torchserve>=0.10.0 torchtext==0.18.0 -torchvision==0.22.0 +torchvision==0.23.0 typing-extensions ts==0.5.1 ray[default] diff --git a/torchx/specs/api.py b/torchx/specs/api.py index e3e954a5b..44d21cbff 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -264,12 +264,32 @@ class RetryPolicy(str, Enum): APPLICATION = "APPLICATION" ROLE = "ROLE" + @staticmethod + def from_str(s: str) -> "RetryPolicy": + """ + Returns the retry policy for the given string. + """ + try: + return RetryPolicy[s] + except KeyError as e: + raise ValueError(f"Invalid retry policy: {s}") from e + class MountType(str, Enum): BIND = "bind" VOLUME = "volume" DEVICE = "device" + @staticmethod + def from_str(s: str) -> "MountType": + """ + Returns the mount type for the given string. + """ + try: + return MountType[s] + except KeyError as e: + raise ValueError(f"Invalid mount type: {s}") from e + @dataclass class BindMount: @@ -481,6 +501,16 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"{self.name} ({self.value})" + @staticmethod + def from_str(s: str) -> "AppState": + """ + Returns the app state for the given string. + """ + try: + return AppState[s] + except KeyError as e: + raise ValueError(f"Invalid app state: {s}") from e + _TERMINAL_STATES: List[AppState] = [ AppState.SUCCEEDED, diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 2490e89b2..cb48652d6 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -31,6 +31,7 @@ macros, MalformedAppHandleException, MISSING, + MountType, NULL_RESOURCE, parse_app_handle, ReplicaStatus, @@ -55,6 +56,20 @@ def test_repr(self) -> None: class AppDefStatusTest(unittest.TestCase): + + def test_app_state_from_str(self) -> None: + self.assertEqual(AppState.UNSUBMITTED, AppState.from_str("UNSUBMITTED")) + self.assertEqual(AppState.SUBMITTED, AppState.from_str("SUBMITTED")) + self.assertEqual(AppState.PENDING, AppState.from_str("PENDING")) + self.assertEqual(AppState.RUNNING, AppState.from_str("RUNNING")) + self.assertEqual(AppState.SUCCEEDED, AppState.from_str("SUCCEEDED")) + self.assertEqual(AppState.FAILED, AppState.from_str("FAILED")) + self.assertEqual(AppState.CANCELLED, AppState.from_str("CANCELLED")) + self.assertEqual(AppState.UNKNOWN, AppState.from_str("UNKNOWN")) + + with self.assertRaises(ValueError): + AppState.from_str("INVALID_STATE") + def test_is_terminal(self) -> None: for s in AppState: is_terminal = AppStatus(state=s).is_terminal() @@ -315,6 +330,26 @@ def test_retry_policies(self) -> None: }, ) + def test_retry_policy_from_str(self) -> None: + # Test valid retry policy strings + self.assertEqual(RetryPolicy.APPLICATION, RetryPolicy.from_str("APPLICATION")) + self.assertEqual(RetryPolicy.REPLICA, RetryPolicy.from_str("REPLICA")) + self.assertEqual(RetryPolicy.ROLE, RetryPolicy.from_str("ROLE")) + + # Test invalid retry policy string + with self.assertRaises(ValueError): + RetryPolicy.from_str("INVALID_POLICY") + + def test_mount_type_from_str(self) -> None: + # Test valid mount type strings + self.assertEqual(MountType.BIND, MountType.from_str("BIND")) + self.assertEqual(MountType.VOLUME, MountType.from_str("VOLUME")) + self.assertEqual(MountType.DEVICE, MountType.from_str("DEVICE")) + + # Test invalid mount type string + with self.assertRaises(ValueError): + MountType.from_str("INVALID_MOUNT_TYPE") + def test_override_role(self) -> None: default = Role( "foobar",