diff --git a/dev-requirements.txt b/dev-requirements.txt index f6dec1675..0abe53c1d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -29,7 +29,7 @@ torch>=2.7.0 torchmetrics==1.6.3 torchserve>=0.10.0 torchtext==0.18.0 -torchvision==0.22.0 +torchvision==0.23.0 typing-extensions ts==0.5.1 ray[default] diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py index 4e3659514..787ee13dd 100644 --- a/torchx/cli/cmd_run.py +++ b/torchx/cli/cmd_run.py @@ -207,8 +207,7 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None: " (e.g. `local_cwd`)" ) - scheduler_opts = runner.scheduler_run_opts(args.scheduler) - cfg = scheduler_opts.cfg_from_str(args.scheduler_args) + cfg = dict(runner.cfg_from_str(args.scheduler, args.scheduler_args)) config.apply(scheduler=args.scheduler, cfg=cfg) component, component_args = _parse_component_name_and_args( @@ -263,12 +262,14 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None: sys.exit(1) except specs.InvalidRunConfigException as e: error_msg = ( - f"Scheduler arg is incorrect or missing required option: `{e.cfg_key}`\n" - f"Run `torchx runopts` to check configuration for `{args.scheduler}` scheduler\n" - f"Use `-cfg` to specify run cfg as `key1=value1,key2=value2` pair\n" - "of setup `.torchxconfig` file, see: https://pytorch.org/torchx/main/experimental/runner.config.html" + "Invalid scheduler configuration: %s\n" + "To configure scheduler options, either:\n" + " 1. Use the `-cfg` command-line argument, e.g., `-cfg key1=value1,key2=value2`\n" + " 2. Set up a `.torchxconfig` file. For more details, visit: https://pytorch.org/torchx/main/runner.config.html\n" + "Run `torchx runopts %s` to check all available configuration options for the " + "`%s` scheduler." ) - logger.error(error_msg) + print(error_msg % (e, args.scheduler, args.scheduler), file=sys.stderr) sys.exit(1) def run(self, args: argparse.Namespace) -> None: diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 67b0cbe75..6ec366303 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -486,6 +486,27 @@ def scheduler_run_opts(self, scheduler: str) -> runopts: """ return self._scheduler(scheduler).run_opts() + def cfg_from_str(self, scheduler: str, *cfg_literal: str) -> Mapping[str, CfgVal]: + """ + Convenience function around the scheduler's ``runopts.cfg_from_str()`` method. + + Usage: + + .. doctest:: + + from torchx.runner import get_runner + + runner = get_runner() + cfg = runner.cfg_from_str("local_cwd", "log_dir=/tmp/foobar", "prepend_cwd=True") + assert cfg == {"log_dir": "/tmp/foobar", "prepend_cwd": True, "auto_set_cuda_visible_devices": False} + """ + + opts = self._scheduler(scheduler).run_opts() + cfg = {} + for cfg_str in cfg_literal: + cfg.update(opts.cfg_from_str(cfg_str)) + return cfg + def scheduler_backends(self) -> List[str]: """ Returns a list of all supported scheduler backends. diff --git a/torchx/runner/test/api_test.py b/torchx/runner/test/api_test.py index c29ee35c9..16181e728 100644 --- a/torchx/runner/test/api_test.py +++ b/torchx/runner/test/api_test.py @@ -28,6 +28,7 @@ parse_app_handle, Resource, Role, + runopts, UnknownAppException, ) from torchx.specs.finder import ComponentNotFoundException @@ -701,3 +702,36 @@ def test_runner_manual_close(self, _) -> None: def test_get_default_runner(self, _) -> None: runner = get_runner() self.assertEqual("torchx", runner._name) + + def test_cfg_from_str(self, _) -> None: + scheduler_mock = MagicMock() + opts = runopts() + opts.add("foo", type_=str, default="", help="") + opts.add("test_key", type_=str, default="", help="") + opts.add("default_time", type_=int, default=0, help="") + opts.add("enable", type_=bool, default=True, help="") + opts.add("disable", type_=bool, default=True, help="") + opts.add("complex_list", type_=List[str], default=[], help="") + scheduler_mock.run_opts.return_value = opts + + with Runner( + name=SESSION_NAME, + scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock}, + ) as runner: + self.assertDictEqual( + { + "foo": "bar", + "test_key": "test_value", + "default_time": 42, + "enable": True, + "disable": False, + "complex_list": ["v1", "v2", "v3"], + }, + runner.cfg_from_str( + "local_dir", + "foo=bar", + "test_key=test_value", + "default_time=42", + "enable=True,disable=False,complex_list=v1;v2;v3", + ), + )