Skip to content

Commit da59cc2

Browse files
committed
Add script_args parameter to ContainerRunner.run
Related to #57
1 parent d02af45 commit da59cc2

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

test/test_core.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929

3030
@patch("xcengine.core.ScriptCreator.__init__")
3131
@pytest.mark.parametrize("tag", [None, "bar"])
32-
@pytest.mark.parametrize("env_file_name", ["environment.yml", "foo.yaml", None])
32+
@pytest.mark.parametrize(
33+
"env_file_name", ["environment.yml", "foo.yaml", None]
34+
)
3335
@pytest.mark.parametrize("use_env_file_param", [False, True])
3436
def test_image_builder_init(
3537
init_mock,
@@ -56,7 +58,11 @@ def test_image_builder_init(
5658
)
5759
assert ib.notebook == nb_path
5860
assert ib.build_dir == build_path
59-
expected_env = environment_path if (use_env_file_param or env_file_name == "environment.yml") else None
61+
expected_env = (
62+
environment_path
63+
if (use_env_file_param or env_file_name == "environment.yml")
64+
else None
65+
)
6066
assert ib.environment == expected_env
6167
if tag is None:
6268
assert abs(
@@ -123,6 +129,28 @@ def test_runner_run_keep(keep: bool):
123129
container.remove.assert_called_once_with(force=True)
124130

125131

132+
def test_runner_extra_args():
133+
runner = xcengine.core.ContainerRunner(
134+
image := Mock(docker.models.images.Image),
135+
None,
136+
client := Mock(DockerClient),
137+
)
138+
image.tags = []
139+
client.containers.run.return_value = (container := MagicMock(Container))
140+
container.status = "exited"
141+
script_args = ["--foo", "--bar", "42", "--baz", "somestring"]
142+
runner.run(
143+
run_batch=False,
144+
host_port=None,
145+
from_saved=False,
146+
keep=False,
147+
script_args=script_args,
148+
)
149+
run_args = client.containers.run.call_args
150+
command = run_args[1]["command"]
151+
assert command == ["python", "execute.py"] + script_args
152+
153+
126154
def test_runner_sigint():
127155
runner = xcengine.core.ContainerRunner(
128156
image := Mock(docker.models.images.Image),

xcengine/core.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
LOGGER = logging.getLogger(__name__)
3434
logging.basicConfig(level=logging.INFO)
3535

36+
3637
class ScriptCreator:
3738
"""Turn a Jupyter notebook into a set of scripts"""
3839

@@ -221,7 +222,7 @@ def __init__(
221222
self.environment = notebook.parent / nb_env
222223
else:
223224
LOGGER.info(f"No environment specified in notebook.")
224-
LOGGER.info(f"Looking for a file named \"environment.yml\".")
225+
LOGGER.info(f'Looking for a file named "environment.yml".')
225226
notebook_sibling = notebook.parent / "environment.yml"
226227
if notebook_sibling.is_file():
227228
self.environment = notebook_sibling
@@ -375,6 +376,7 @@ def run(
375376
host_port: int | None,
376377
from_saved: bool,
377378
keep: bool,
379+
script_args: list[str] | None = None,
378380
):
379381
LOGGER.info(f"Running container from image {self.image.short_id}")
380382
LOGGER.info(f"Image tags: {' '.join(self.image.tags)}")
@@ -391,6 +393,9 @@ def run(
391393
else []
392394
)
393395
+ (["--from-saved"] if from_saved else [])
396+
+ script_args
397+
if script_args is not None
398+
else []
394399
)
395400
run_args: dict[str, Any] = dict(
396401
image=self.image, command=command, remove=False, detach=True
@@ -400,10 +405,14 @@ def run(
400405
container: Container = self.client.containers.run(**run_args)
401406
LOGGER.info(f"Waiting for container {container.short_id} to complete.")
402407
default_sigint_handler = signal.getsignal(signal.SIGINT)
408+
403409
def signal_hander(signum, frame):
404410
signal.signal(signal.SIGINT, default_sigint_handler)
405-
LOGGER.info(f"Caught SIGINT. Stopping container {container.short_id}")
411+
LOGGER.info(
412+
f"Caught SIGINT. Stopping container {container.short_id}"
413+
)
406414
container.stop()
415+
407416
signal.signal(signal.SIGINT, signal_hander)
408417
while container.status in {"created", "running"}:
409418
LOGGER.debug(

0 commit comments

Comments
 (0)