Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit d0be22e

Browse files
authored
LightningCLI and RayPlugin compatibility (#164)
Proper fix for #151 removing what #154 added.
1 parent 761ab33 commit d0be22e

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ jobs:
131131
pushd ray_lightning/tests
132132
python -m pytest -v --durations=0 -x test_ddp.py
133133
python -m pytest -v --durations=0 -x test_ddp_sharded.py
134+
python -m pytest -v --durations=0 -x test_lightning_cli.py
134135
135136
test_linux_ray_release_2:
136137
runs-on: ubuntu-latest
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
from importlib.util import find_spec
3+
from pytorch_lightning.utilities.cli import LightningCLI
4+
from ray_lightning import RayStrategy
5+
from ray_lightning.tests.utils import BoringModel
6+
from unittest import mock
7+
8+
9+
@pytest.mark.skipif(
10+
not find_spec("jsonargparse"), reason="jsonargparse required")
11+
def test_lightning_cli_raystrategy_instantiation():
12+
init_args = {
13+
"num_workers": 4, # Resolve from RayStrategy.__init__
14+
"use_gpu": False, # Resolve from RayStrategy.__init__
15+
"bucket_cap_mb": 50, # Resolve from DistributedDataParallel.__init__
16+
}
17+
cli_args = ["--trainer.strategy=RayStrategy"]
18+
cli_args += [f"--trainer.strategy.{k}={v}" for k, v in init_args.items()]
19+
20+
with mock.patch("sys.argv", ["any.py"] + cli_args):
21+
cli = LightningCLI(BoringModel, run=False)
22+
23+
assert isinstance(cli.config_init["trainer"]["strategy"], RayStrategy)
24+
assert {
25+
k: cli.config["trainer"]["strategy"]["init_args"][k]
26+
for k in init_args
27+
} == init_args

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ torch==1.12.0
1111
torchmetrics
1212
torchvision
1313
protobuf<=3.20.1
14+
jsonargparse>=4.13.2

0 commit comments

Comments
 (0)