This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit afa2389

Fairscale Integration (#42)

* wip
* pass in ddp_kwargs
* compat
* format
* fixes
* updates
* example
* readme
* remove hvd.init
* fix merge
* upgrade flake8

1 parent: 0a47405

11 files changed: +425 −35 lines

.github/workflows/test.yaml

Lines changed: 3 additions & 1 deletion

```diff
@@ -16,7 +16,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         python -m pip install codecov
-        python -m pip install -U yapf==0.23.0 flake8==3.7.7 flake8-comprehensions flake8-quotes==2.0.0
+        python -m pip install -U yapf==0.23.0 flake8==3.9.1 flake8-comprehensions flake8-quotes
     - name: Run format script
       run: |
         ./format.sh --all
@@ -46,6 +46,7 @@ jobs:
         python -m pytest -v --durations=0 -x test_ddp.py
         python -m pytest -v --durations=0 -x test_horovod.py
         python -m pytest -v --durations=0 -x test_tune.py
+        python -m pytest -v --durations=0 -x test_ddp_sharded.py
 
   test_linux_ray_master_examples:
     runs-on: ubuntu-latest
@@ -102,6 +103,7 @@ jobs:
         python -m pytest -v --durations=0 -x test_ddp.py
         python -m pytest -v --durations=0 -x test_horovod.py
         python -m pytest -v --durations=0 -x test_tune.py
+        python -m pytest -v --durations=0 -x test_ddp_sharded.py
 
 
   test_linux_ray_release_examples:
```

README.md

Lines changed: 21 additions & 1 deletion

````diff
@@ -39,7 +39,7 @@ Because Ray is used to launch processes, instead of the same script being called
 Or if you prefer to use Horovod as the distributed training protocol, use the `HorovodRayPlugin` instead.
 
 ```python
-import pytorch_lightning as ptl
+import pytorch_lightning as pl
 from ray_lightning import HorovodRayPlugin
 
 # Create your PyTorch Lightning model here.
@@ -54,12 +54,32 @@ trainer = pl.Trainer(..., gpus=1, plugins=[plugin])
 trainer.fit(ptl_model)
 ```
 
+## Model Parallel Sharded Training on Ray
+The `RayShardedPlugin` integrates with [FairScale](https://github.com/facebookresearch/fairscale) to provide sharded DDP training on a Ray cluster.
+With sharded training, leverage the scalability of data parallel training while drastically reducing memory usage when training large models.
+
+```python
+import pytorch_lightning as pl
+from ray_lightning import RayShardedPlugin
+
+# Create your PyTorch Lightning model here.
+ptl_model = MNISTClassifier(...)
+plugin = RayShardedPlugin(num_workers=4, cpus_per_worker=1, use_gpu=True)
+
+# If using GPUs, set the ``gpus`` arg to a value > 0.
+# The actual number of GPUs is determined by ``num_workers``.
+trainer = pl.Trainer(..., gpus=1, plugins=[plugin])
+trainer.fit(ptl_model)
+```
+See the [Pytorch Lightning docs](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html#sharded-training) for more information on sharded training.
+
 ## Multi-node Distributed Training
 Using the same examples above, you can run distributed training on a multi-node cluster with just 2 simple steps.
 1) [Use Ray's cluster launcher](https://docs.ray.io/en/master/cluster/launcher.html) to start a Ray cluster- `ray up my_cluster_config.yaml`.
 2) [Execute your Python script on the Ray cluster](https://docs.ray.io/en/master/cluster/commands.html#running-ray-scripts-on-the-cluster-ray-submit)- `ray submit my_cluster_config.yaml train.py`. This will `rsync` your training script to the head node, and execute it on the Ray cluster.
 
 You no longer have to set environment variables or configurations and run your training script on every single node.
+
 ## Hyperparameter Tuning with Ray Tune
 `ray_lightning` also integrates with Ray Tune to provide distributed hyperparameter tuning for your distributed model training. You can run multiple PyTorch Lightning training runs in parallel, each with a different hyperparameter configuration, and each training run parallelized by itself. All you have to do is move your training code to a function, pass the function to tune.run, and make sure to add the appropriate callback (Either `TuneReportCallback` or `TuneReportCheckpointCallback`) to your PyTorch Lightning Trainer.
 
````
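For reference, a minimal sketch of the Tune workflow described in the last README paragraph above. It assumes the `TuneReportCallback` from `ray.tune.integration.pytorch_lightning`, reuses the placeholder `MNISTClassifier` model and a hypothetical `ptl/val_loss` metric name from the README snippets, and omits per-trial resource reservation for the Ray worker actors; it is shown with `RayPlugin`, and `RayShardedPlugin` would drop in the same way.

```python
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import pytorch_lightning as pl
from ray_lightning import RayPlugin


def train_mnist(config):
    # `MNISTClassifier` is the same placeholder model used in the README
    # snippets; "ptl/val_loss" stands in for whatever metric it logs.
    model = MNISTClassifier(config)
    trainer = pl.Trainer(
        max_epochs=4,
        callbacks=[
            TuneReportCallback({"loss": "ptl/val_loss"}, on="validation_end")
        ],
        plugins=[RayPlugin(num_workers=2, use_gpu=False)])
    trainer.fit(model)


# Each Tune trial runs `train_mnist` with a sampled config; the callback
# reports the logged metric back to Tune at the end of every validation epoch.
analysis = tune.run(
    train_mnist,
    metric="loss",
    mode="min",
    config={"lr": tune.loguniform(1e-4, 1e-1)},
    num_samples=4)
print(analysis.best_config)
```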
Lines changed: 134 additions & 0 deletions

```python
import os
import tempfile
import time

import ray
import torch
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models.vision import ImageGPT

import pytorch_lightning as pl
from pytorch_lightning import Callback

from ray_lightning import RayShardedPlugin


class CUDACallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
        epoch_time = time.time() - self.start_time

        max_memory = torch.tensor(
            max_memory, dtype=torch.int, device=trainer.root_gpu)
        epoch_time = torch.tensor(
            epoch_time, dtype=torch.int, device=trainer.root_gpu)

        torch.distributed.all_reduce(
            max_memory, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(
            epoch_time, op=torch.distributed.ReduceOp.SUM)

        world_size = torch.distributed.get_world_size()

        print(
            f"Average Epoch time: {epoch_time.item() / float(world_size):.2f} "
            f"seconds")
        print(
            f"Average Peak memory {max_memory.item() / float(world_size):.2f}"
            f"MiB")


def train(data_dir, num_workers, use_gpu, batch_size, embed_dim, max_epochs,
          max_steps):
    # Make sure data is downloaded on all nodes.
    def download_data():
        from filelock import FileLock
        with FileLock(os.path.join(data_dir, ".lock")):
            MNISTDataModule(data_dir=data_dir).prepare_data()

    plugin = RayShardedPlugin(
        num_workers=num_workers, use_gpu=use_gpu, init_hook=download_data)

    dm = MNISTDataModule(data_dir, batch_size=batch_size)

    model = ImageGPT(
        embed_dim=embed_dim, layers=16, heads=4, vocab_size=32, num_pixels=28)

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        gpus=int(use_gpu),
        precision=16 if use_gpu else 32,
        callbacks=[CUDACallback()] if use_gpu else [],
        plugins=plugin,
        max_steps=max_steps)

    trainer.fit(model, dm)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-workers",
        type=int,
        help="Number of training workers to use.",
        default=1)
    parser.add_argument(
        "--use-gpu", action="store_true", help="Use GPU for training.")
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=10,
        help="Number of epochs to train for.")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size to use for training.")
    parser.add_argument(
        "--embed-dim",
        type=int,
        default=2048,
        help="Number of embedding dimensions for ImageGPT model.")
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        ray.init(num_cpus=2)
    else:
        ray.init(address=args.address)

    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")

    if args.smoke_test:
        train(
            data_dir=data_dir,
            num_workers=2,
            use_gpu=False,
            batch_size=32,
            embed_dim=16,
            max_epochs=1,
            max_steps=1)
    else:
        train(
            data_dir=data_dir,
            num_workers=args.num_workers,
            use_gpu=args.use_gpu,
            batch_size=args.batch_size,
            embed_dim=args.embed_dim,
            max_epochs=args.num_epochs,
            max_steps=None)
```

ray_lightning/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,4 +1,5 @@
 from ray_lightning.ray_ddp import RayPlugin
 from ray_lightning.ray_horovod import HorovodRayPlugin
+from ray_lightning.ray_ddp_sharded import RayShardedPlugin
 
-__all__ = ["RayPlugin", "HorovodRayPlugin"]
+__all__ = ["RayPlugin", "HorovodRayPlugin", "RayShardedPlugin"]
```

ray_lightning/ray_ddp.py

Lines changed: 39 additions & 18 deletions

```diff
@@ -146,15 +146,7 @@ def get_local_ranks(self) -> Dict[int, int]:
             rank_counter_dict[ip] += 1
         return global_to_local
 
-    def start_training(self, trainer):
-        """Main training loop.
-
-        Sets up the torch.distributed process group for each training
-        worker. Then trigger remote training via ``train_remote`` on each
-        worker. If using with Ray Tune, create a communication queue to
-        revieve intermediate results, and process those results. Finally
-        retrieve the training results from the rank 0 worker and return."""
-
+    def _setup_env_vars(self):
         # Get rank 0 worker address and port for DDP connection.
         os.environ["MASTER_ADDR"] = ray.get(
             self.workers[0].get_node_ip.remote())
@@ -169,6 +161,19 @@ def start_training(self, trainer):
         values = [os.getenv(k) for k in keys]
         ray.get([w.set_env_vars.remote(keys, values) for w in self.workers])
 
+    def execution_loop(self, trainer, tune_enabled: bool = True):
+        """Main execution loop for training, testing, & prediction.
+
+        Sets up the torch.distributed process group for each
+        worker. Then trigger remote training/testing/eval via
+        ``train_remote`` on each worker. If using with Ray Tune, create a
+        communication queue to retrieve intermediate results, and process
+        those results. Finally retrieve the training results from the rank 0
+        worker and return."""
+
+        # Sets environment variables for all workers.
+        self._setup_env_vars()
+
         self.global_to_local = self.get_local_ranks()
 
         model = self._model
@@ -177,12 +182,12 @@ def start_training(self, trainer):
         self._model = None
 
         queue = None
-        if TUNE_INSTALLED and is_session_enabled():
+        if tune_enabled and TUNE_INSTALLED and is_session_enabled():
             # Create communication queue and send to all the workers.
             queue = Queue(actor_options={"num_cpus": 0})
 
         futures = [
-            self.workers[i].execute.remote(self.train_remote, model_ref, i,
+            self.workers[i].execute.remote(self.execute_remote, model_ref, i,
                                            queue)
             for i in range(self.num_workers)
         ]
@@ -195,7 +200,7 @@ def start_training(self, trainer):
         self._model = model
         self._model.load_state_dict(state_dict)
         if self.lightning_module.trainer.checkpoint_callback:
-            self.lightning_module.trainer.checkpoint_callback\
+            self.lightning_module.trainer.checkpoint_callback \
                 .best_model_path = best_path
 
         if queue:
@@ -204,6 +209,21 @@ def start_training(self, trainer):
 
         return results
 
+    def start_training(self, trainer):
+        results = self.execution_loop(trainer, tune_enabled=True)
+        # reset optimizers, since main process is never used for training and
+        # thus does not have a valid optim state.
+        trainer.optimizers = []
+        return results
+
+    def start_testing(self, trainer):
+        results = self.execution_loop(trainer, tune_enabled=False)
+        return results
+
+    def start_predicting(self, trainer):
+        results = self.execution_loop(trainer, tune_enabled=False)
+        return results
+
     def post_dispatch(self):
         """Shutdown the DDP process group and all the Ray actors. """
 
@@ -220,18 +240,19 @@ def shutdown_remote():
 
     # All methods below are only executed in remote Ray workers.
 
-    def train_remote(self,
-                     model: LightningModule,
-                     global_rank: int,
-                     queue: Queue = None):
-        """Training function to be executed on each remote worker."""
+    def execute_remote(self,
+                       model: LightningModule,
+                       global_rank: int,
+                       queue: Queue = None):
+        """Train/test/eval function to be executed on each remote worker."""
         assert isinstance(self, RayPlugin)
         # This method should be executed remotely in each worker.
         self._model = model
         self.lightning_module.trainer.accelerator_connector\
             ._training_type_plugin = self
         self.lightning_module.trainer.accelerator.training_type_plugin = self
         self.cluster_environment.set_global_rank(global_rank)
+        self.cluster_environment.set_remote_execution(True)
 
         if queue is not None:
             # Initialize session.
@@ -272,7 +293,7 @@ def init_ddp_connection(self,
     def set_world_ranks(self, process_idx: int = 0):
         """Set the appropriate rank attribues for the trainer."""
         assert self.cluster_environment is not None
-        if self.global_rank is not None:
+        if self.cluster_environment.is_remote():
             self._local_rank = self.global_to_local[self.global_rank]
             self.cluster_environment.set_global_rank(self.global_rank)
         self.cluster_environment.set_world_size(self.num_workers)
```
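The `execute_remote` and `set_world_ranks` changes above rely on two cluster-environment methods, `set_remote_execution()` and `is_remote()`, whose implementation lives in a file not shown in this diff. A minimal sketch of what that flag amounts to, using a hypothetical standalone class rather than the package's actual Ray cluster environment:

```python
class RemoteExecutionFlag:
    """Hypothetical stand-in for the Ray cluster environment's remote flag."""

    def __init__(self):
        self._remote = False

    def set_remote_execution(self, remote: bool) -> None:
        # Called from ``execute_remote`` on each Ray worker, so the plugin
        # can tell driver-side setup apart from worker-side setup.
        self._remote = remote

    def is_remote(self) -> bool:
        # Checked in ``set_world_ranks``: local and global ranks are only
        # resolved once the code is actually running on a remote worker.
        return self._remote


env = RemoteExecutionFlag()
assert not env.is_remote()    # on the driver process
env.set_remote_execution(True)
assert env.is_remote()        # on a worker, after execute_remote runs
```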

ray_lightning/ray_ddp_sharded.py

Lines changed: 66 additions & 0 deletions

```python
from typing import Optional

import torch
from torch.optim import Optimizer

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only

from ray_lightning import RayPlugin

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
    from fairscale.optim import OSS

    from pytorch_lightning.overrides.fairscale import \
        LightningShardedDataParallel, unwrap_lightning_module_sharded


class RayShardedPlugin(RayPlugin):
    def configure_ddp(self):
        self._wrap_optimizers()
        self._model = ShardedDataParallel(
            LightningShardedDataParallel(self.model),
            sharded_optimizer=self.lightning_module.trainer.optimizers)
        setattr(self._model, "require_backward_grad_sync", False)

    def _reinit_optimizers_with_oss(self):
        optimizers = self.lightning_module.trainer.optimizers
        for x, optimizer in enumerate(optimizers):
            if not isinstance(optimizer, OSS):
                optim_class = type(optimizer)
                zero_optimizer = OSS(
                    params=optimizer.param_groups,
                    optim=optim_class,
                    **optimizer.defaults)
                optimizers[x] = zero_optimizer
                del optimizer
        trainer = self.lightning_module.trainer
        trainer.optimizers = optimizers

    def _wrap_optimizers(self):
        trainer = self.model.trainer
        if trainer.testing:
            return
        self._reinit_optimizers_with_oss()

    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
        if isinstance(optimizer, OSS):
            optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)

    @rank_zero_only
    def _optim_state_dict(self, optimizer):
        """Retrieves state dict only on rank 0."""
        return optimizer.state_dict()

    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_sharded(self._model)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool,
                     optimizer: Optimizer, opt_idx: int):
        pass

    def post_training_step(self):
        pass
```
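To make the `_reinit_optimizers_with_oss` rewrap above concrete, here is a standalone sketch of the same pattern, assuming FairScale is installed and using a throwaway single-process gloo group so `OSS` can query the process group; the real plugin performs this inside the DDP process group already set up on the Ray workers.

```python
import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Single-process process group so OSS can resolve world size and rank.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# The same rewrap the plugin performs: replace the plain optimizer with an
# OSS optimizer that keeps the original optimizer class and defaults but
# shards the optimizer state across ranks.
sharded_optimizer = OSS(
    params=base_optimizer.param_groups,
    optim=type(base_optimizer),
    **base_optimizer.defaults)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
sharded_optimizer.step()

dist.destroy_process_group()
```

Keeping the original optimizer's class and defaults means every rank applies the same update rule, while each rank only holds the optimizer state for its own shard of the parameters.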
