|
15 | 15 | """
|
16 | 16 |
|
17 | 17 | import os
|
| 18 | +import shutil |
| 19 | +import tempfile |
| 20 | +import time |
18 | 21 | import unittest
|
19 | 22 | from types import ModuleType
|
| 23 | +from typing import Any, Callable, Dict, Optional |
| 24 | + |
| 25 | +from torchx.runner import get_runner |
| 26 | + |
| 27 | +from torchx.specs import AppDef, AppStatus |
20 | 28 |
|
21 | 29 | from torchx.specs.builders import _create_args_parser
|
22 | 30 | from torchx.specs.finder import get_component
|
23 | 31 |
|
24 | 32 |
|
25 | 33 | class ComponentTestCase(unittest.TestCase):
|
| 34 | + def setUp(self) -> None: |
| 35 | + self.test_dir = tempfile.mkdtemp("torchx_component_test") |
| 36 | + |
| 37 | + self.old_cwd = os.getcwd() |
| 38 | + os.chdir(os.path.dirname(os.path.dirname(__file__))) |
| 39 | + |
| 40 | + def tearDown(self) -> None: |
| 41 | + shutil.rmtree(self.test_dir) |
| 42 | + os.chdir(self.old_cwd) |
| 43 | + |
26 | 44 | """
|
27 | 45 | ComponentTestCase is an extension of TestCase with helper methods for use
|
28 | 46 | with testing component definitions.
|
@@ -58,3 +76,43 @@ def validate(self, module: ModuleType, function_name: str) -> None:
|
58 | 76 | # this will raise an exception and the test will fail
|
59 | 77 | with self.assertRaises(SystemExit):
|
60 | 78 | _ = _create_args_parser(component_def.fn).parse_args(["--help"])
|
| 79 | + |
| 80 | + def run_component( |
| 81 | + self, |
| 82 | + component: Callable[..., AppDef], |
| 83 | + args: Optional[Dict[str, Any]] = None, |
| 84 | + scheduler_params: Optional[Dict[str, Any]] = None, |
| 85 | + scheduler: str = "local_cwd", |
| 86 | + interval: float = 0.1, |
| 87 | + timeout: float = 1, |
| 88 | + ) -> Optional[AppStatus]: |
| 89 | + """ |
| 90 | + Helper function that hides complexity of setting up the runner and polling results. |
| 91 | + Note: method is blocking until either scheduler exits or timeout is reached (for non-blocking schedulers). |
| 92 | +
|
| 93 | + Args: |
| 94 | + components: component function, factory for AppDef |
| 95 | + args: optional component factory arguments |
| 96 | + scheduler_params: optional parameters for scheduler factory method |
| 97 | + scheduler: scheduler name |
| 98 | + interval: scheduler comppletion polling interval |
| 99 | + timeout: max time for scheduler to complete |
| 100 | +
|
| 101 | + """ |
| 102 | + |
| 103 | + app_def = component(**args) |
| 104 | + if scheduler_params: |
| 105 | + runner = get_runner(name=None, component_defaults=None, **scheduler_params) |
| 106 | + else: |
| 107 | + runner = get_runner(name=None, component_defaults=None) |
| 108 | + |
| 109 | + app_handle = runner.run(app_def, scheduler) |
| 110 | + |
| 111 | + elapsed = 0 |
| 112 | + status = runner.status(app_handle) |
| 113 | + while (status and not status.is_terminal) or elapsed < timeout: |
| 114 | + time.sleep(interval) |
| 115 | + elapsed += interval |
| 116 | + status = runner.status(app_handle) |
| 117 | + |
| 118 | + return status |
0 commit comments