Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 41e3491

Browse files
authored
Support PTL 1.4 (#58)
* support 1.3.8
* increase timeout
* upgrade
* fix failing test
* split test suite
* fix
* bump up timeouts more
1 parent 4300391 commit 41e3491

File tree

6 files changed

+74
-79
lines changed

6 files changed

+74
-79
lines changed

.github/workflows/test.yaml

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ jobs:
2020
- name: Run format script
2121
run: |
2222
./format.sh --all
23-
test_linux_ray_master:
23+
24+
test_linux_ray_master_1:
2425
runs-on: ubuntu-latest
25-
timeout-minutes: 25
26+
timeout-minutes: 40
2627
steps:
2728
- uses: actions/checkout@v2
2829
- name: Set up Python 3.7
@@ -44,13 +45,37 @@ jobs:
4445
run: |
4546
pushd ray_lightning/tests
4647
python -m pytest -v --durations=0 -x test_ddp.py
48+
python -m pytest -v --durations=0 -x test_ddp_sharded.py
49+
50+
test_linux_ray_master_2:
51+
runs-on: ubuntu-latest
52+
timeout-minutes: 40
53+
steps:
54+
- uses: actions/checkout@v2
55+
- name: Set up Python 3.7
56+
uses: actions/setup-python@v2
57+
with:
58+
python-version: 3.7
59+
- name: Install dependencies
60+
run: |
61+
python -m pip install --upgrade pip
62+
python -m pip install --upgrade setuptools
63+
python -m pip install codecov
64+
python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl
65+
if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi
66+
HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install git+https://github.com/horovod/horovod.git
67+
- name: Install package
68+
run: |
69+
python -m pip install -e .
70+
- name: Test with Pytest
71+
run: |
72+
pushd ray_lightning/tests
4773
python -m pytest -v --durations=0 -x test_horovod.py
4874
python -m pytest -v --durations=0 -x test_tune.py
49-
python -m pytest -v --durations=0 -x test_ddp_sharded.py
5075
5176
test_linux_ray_master_examples:
5277
runs-on: ubuntu-latest
53-
timeout-minutes: 25
78+
timeout-minutes: 40
5479
steps:
5580
- uses: actions/checkout@v2
5681
- name: Set up Python 3.7
@@ -83,9 +108,9 @@ jobs:
83108
echo "running examples with Ray Client 3" && python -m pytest -v --durations=0 -x test_client_3.py
84109
85110
86-
test_linux_ray_release:
111+
test_linux_ray_release_1:
87112
runs-on: ubuntu-latest
88-
timeout-minutes: 25
113+
timeout-minutes: 40
89114
steps:
90115
- uses: actions/checkout@v2
91116
- name: Set up Python 3.7
@@ -107,14 +132,38 @@ jobs:
107132
run: |
108133
pushd ray_lightning/tests
109134
python -m pytest -v --durations=0 -x test_ddp.py
135+
python -m pytest -v --durations=0 -x test_ddp_sharded.py
136+
137+
test_linux_ray_release_2:
138+
runs-on: ubuntu-latest
139+
timeout-minutes: 40
140+
steps:
141+
- uses: actions/checkout@v2
142+
- name: Set up Python 3.7
143+
uses: actions/setup-python@v2
144+
with:
145+
python-version: 3.7
146+
- name: Install dependencies
147+
run: |
148+
python -m pip install --upgrade pip
149+
python -m pip install --upgrade setuptools
150+
python -m pip install codecov
151+
python -m pip install -U ray
152+
if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi
153+
HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git
154+
- name: Install package
155+
run: |
156+
python -m pip install -e .
157+
- name: Test with Pytest
158+
run: |
159+
pushd ray_lightning/tests
110160
python -m pytest -v --durations=0 -x test_horovod.py
111161
python -m pytest -v --durations=0 -x test_tune.py
112-
python -m pytest -v --durations=0 -x test_ddp_sharded.py
113162
114163
115164
test_linux_ray_release_examples:
116165
runs-on: ubuntu-latest
117-
timeout-minutes: 25
166+
timeout-minutes: 40
118167
steps:
119168
- uses: actions/checkout@v2
120169
- name: Set up Python 3.7

ray_lightning/ray_ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def start_training(self, trainer):
216216
trainer.optimizers = []
217217
return results
218218

219-
def start_testing(self, trainer):
219+
def start_evaluating(self, trainer):
220220
results = self.execution_loop(trainer, tune_enabled=False)
221221
return results
222222

ray_lightning/ray_ddp_sharded.py

Lines changed: 3 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,7 @@
1-
from typing import Optional
2-
3-
import torch
4-
from torch.optim import Optimizer
5-
6-
from pytorch_lightning import LightningModule
7-
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
1+
from pytorch_lightning.plugins import DDPSpawnShardedPlugin
82

93
from ray_lightning import RayPlugin
104

11-
if _FAIRSCALE_AVAILABLE:
12-
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
13-
from fairscale.optim import OSS
14-
15-
from pytorch_lightning.overrides.fairscale import \
16-
LightningShardedDataParallel, unwrap_lightning_module_sharded
17-
18-
19-
class RayShardedPlugin(RayPlugin):
20-
def configure_ddp(self):
21-
self._wrap_optimizers()
22-
self._model = ShardedDataParallel(
23-
LightningShardedDataParallel(self.model),
24-
sharded_optimizer=self.lightning_module.trainer.optimizers)
25-
setattr(self._model, "require_backward_grad_sync", False)
26-
27-
def _reinit_optimizers_with_oss(self):
28-
optimizers = self.lightning_module.trainer.optimizers
29-
for x, optimizer in enumerate(optimizers):
30-
if not isinstance(optimizer, OSS):
31-
optim_class = type(optimizer)
32-
zero_optimizer = OSS(
33-
params=optimizer.param_groups,
34-
optim=optim_class,
35-
**optimizer.defaults)
36-
optimizers[x] = zero_optimizer
37-
del optimizer
38-
trainer = self.lightning_module.trainer
39-
trainer.optimizers = optimizers
40-
41-
def _wrap_optimizers(self):
42-
trainer = self.model.trainer
43-
if trainer.testing:
44-
return
45-
self._reinit_optimizers_with_oss()
46-
47-
def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
48-
if isinstance(optimizer, OSS):
49-
optimizer.consolidate_state_dict()
50-
return self._optim_state_dict(optimizer)
51-
52-
@rank_zero_only
53-
def _optim_state_dict(self, optimizer):
54-
"""Retrieves state dict only on rank 0."""
55-
return optimizer.state_dict()
56-
57-
@property
58-
def lightning_module(self) -> LightningModule:
59-
return unwrap_lightning_module_sharded(self._model)
60-
61-
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool,
62-
optimizer: Optimizer, opt_idx: int):
63-
pass
645

65-
def post_training_step(self):
66-
pass
6+
class RayShardedPlugin(RayPlugin, DDPSpawnShardedPlugin):
7+
pass

ray_lightning/tests/test_ddp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,22 @@ def test_early_stop(tmpdir, ray_start_2_cpus):
154154
"""Tests if early stopping callback works correctly."""
155155
model = BoringModel()
156156
plugin = RayPlugin(num_workers=1, use_gpu=False)
157-
early_stop = EarlyStopping(monitor="val_loss", patience=2, verbose=True)
157+
patience = 2
158+
early_stop = EarlyStopping(
159+
monitor="val_loss", patience=patience, verbose=True)
158160
trainer = get_trainer(
159161
tmpdir,
160162
max_epochs=500,
161163
plugins=[plugin],
162164
callbacks=[early_stop],
165+
num_sanity_val_steps=0,
163166
limit_train_batches=1.0,
164167
limit_val_batches=1.0,
165168
progress_bar_refresh_rate=1)
166169
trainer.fit(model)
167170
trained_model = BoringModel.load_from_checkpoint(
168171
trainer.checkpoint_callback.best_model_path)
169-
assert trained_model.val_epoch == 2, trained_model.val_epoch
172+
assert trained_model.val_epoch == patience + 1, trained_model.val_epoch
170173

171174

172175
def test_unused_parameters(tmpdir, ray_start_2_cpus):

ray_lightning/tests/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,30 +153,32 @@ def get_trainer(dir,
153153
limit_val_batches: int = 10,
154154
progress_bar_refresh_rate: int = 0,
155155
callbacks: Optional[List[Callback]] = None,
156-
checkpoint_callback: bool = True) -> Trainer:
156+
checkpoint_callback: bool = True,
157+
**trainer_kwargs) -> Trainer:
157158
"""Returns a Pytorch Lightning Trainer with the provided arguments."""
158159
callbacks = [] if not callbacks else callbacks
159160
trainer = pl.Trainer(
160161
default_root_dir=dir,
161162
gpus=1 if use_gpu else 0,
163+
callbacks=callbacks,
164+
plugins=plugins,
162165
max_epochs=max_epochs,
163166
limit_train_batches=limit_train_batches,
164167
limit_val_batches=limit_val_batches,
165168
progress_bar_refresh_rate=progress_bar_refresh_rate,
166169
checkpoint_callback=checkpoint_callback,
167-
callbacks=callbacks,
168-
plugins=plugins)
170+
**trainer_kwargs)
169171
return trainer
170172

171173

172174
def train_test(trainer: Trainer, model: LightningModule):
173175
"""Checks if training the provided model updates its weights."""
174176
initial_values = torch.tensor(
175177
[torch.sum(torch.abs(x)) for x in model.parameters()])
176-
result = trainer.fit(model)
178+
trainer.fit(model)
177179
post_train_values = torch.tensor(
178180
[torch.sum(torch.abs(x)) for x in model.parameters()])
179-
assert result == 1, "trainer failed"
181+
assert trainer.state.finished, f"Trainer failed with {trainer.state}"
180182
# Check that the model is actually changed post-training.
181183
assert torch.norm(initial_values - post_train_values) > 0.1
182184

requirements-test.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ flake8-comprehensions
44
flake8-quotes
55
yapf==0.23.0
66
pytest
7-
pytorch-lightning==1.2.10
7+
pytorch-lightning==1.4.1
88
lightning-bolts==0.3.3
99
ray[tune]
1010
torch==1.8.1

0 commit comments

Comments (0)