This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 4dcbdf0

Custom Accelerator so driver doesn't require GPU (#67)
* wip
* wip
* support 1.3.8
* increase timeout
* upgrade
* fix failing test
* wip
* update
* add comment
* fix
* remove server address
1 parent 41e3491 commit 4dcbdf0

File tree: 8 files changed (+71, −115 lines)


README.md

Lines changed: 6 additions & 6 deletions
@@ -25,9 +25,9 @@ from ray_lightning import RayPlugin
 ptl_model = MNISTClassifier(...)
 plugin = RayPlugin(num_workers=4, cpus_per_worker=1, use_gpu=True)
 
-# If using GPUs, set the ``gpus`` arg to a value > 0.
+# Don't set ``gpus`` in the ``Trainer``.
 # The actual number of GPUs is determined by ``num_workers``.
-trainer = pl.Trainer(..., gpus=1, plugins=[plugin])
+trainer = pl.Trainer(..., plugins=[plugin])
 trainer.fit(ptl_model)
 ```
 
@@ -48,9 +48,9 @@ ptl_model = MNISTClassifier(...)
 # 2 nodes, 4 workers per node, each using 1 CPU and 1 GPU.
 plugin = HorovodRayPlugin(num_hosts=2, num_slots=4, use_gpu=True)
 
-# If using GPUs, set the ``gpus`` arg to a value > 0.
+# Don't set ``gpus`` in the ``Trainer``.
 # The actual number of GPUs is determined by ``num_slots``.
-trainer = pl.Trainer(..., gpus=1, plugins=[plugin])
+trainer = pl.Trainer(..., plugins=[plugin])
 trainer.fit(ptl_model)
 ```
 
@@ -66,9 +66,9 @@ from ray_lightning import RayShardedPlugin
 ptl_model = MNISTClassifier(...)
 plugin = RayShardedPlugin(num_workers=4, cpus_per_worker=1, use_gpu=True)
 
-# If using GPUs, set the ``gpus`` arg to a value > 0.
+# Don't set ``gpus`` in the ``Trainer``.
 # The actual number of GPUs is determined by ``num_workers``.
-trainer = pl.Trainer(..., gpus=1, plugins=[plugin])
+trainer = pl.Trainer(..., plugins=[plugin])
 trainer.fit(ptl_model)
 ```
 
 See the [Pytorch Lightning docs](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html#sharded-training) for more information on sharded training.
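Taken together, the README change means the ``Trainer`` no longer takes a ``gpus`` argument at all; the plugin is the single source of truth for resources. Below is a self-contained sketch of the updated usage, with a toy ``LightningModule`` standing in for the README's ``MNISTClassifier`` (the model, data, and worker counts here are illustrative assumptions, not part of the commit):

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from ray_lightning import RayPlugin


class ToyModel(pl.LightningModule):
    """Minimal stand-in for the README's MNISTClassifier."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 1))
    train_loader = DataLoader(dataset, batch_size=8)

    # The plugin decides where training runs; with this commit the driver
    # process can be a CPU-only machine even when use_gpu=True.
    plugin = RayPlugin(num_workers=2, cpus_per_worker=1, use_gpu=False)

    # Note: no ``gpus`` argument on the Trainer anymore.
    trainer = pl.Trainer(max_epochs=1, plugins=[plugin])
    trainer.fit(ToyModel(), train_loader)
```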

ray_lightning/examples/ray_ddp_example.py

Lines changed: 0 additions & 9 deletions
@@ -72,7 +72,6 @@ def train_mnist(config,
 
     trainer = pl.Trainer(
         max_epochs=num_epochs,
-        gpus=int(use_gpu),
         callbacks=callbacks,
         plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
         **trainer_kwargs)
@@ -111,7 +110,6 @@ def tune_mnist(data_dir,
         num_samples=num_samples,
         resources_per_trial={
             "cpu": 1,
-            "gpu": int(use_gpu),
             "extra_cpu": num_workers,
             "extra_gpu": num_workers * int(use_gpu)
         },
@@ -152,11 +150,6 @@ def tune_mnist(data_dir,
         required=False,
         type=str,
         help="the address to use for Ray")
-    parser.add_argument(
-        "--server-address",
-        required=False,
-        type=str,
-        help="If using Ray Client, the address of the server to connect to. ")
     args, _ = parser.parse_known_args()
 
     num_epochs = 1 if args.smoke_test else args.num_epochs
@@ -166,8 +159,6 @@ def tune_mnist(data_dir,
 
     if args.smoke_test:
         ray.init(num_cpus=2)
-    elif args.server_address:
-        ray.util.connect(args.server_address)
     else:
         ray.init(address=args.address)
 

ray_lightning/examples/ray_ddp_sharded_example.py

Lines changed: 0 additions & 1 deletion
@@ -63,7 +63,6 @@ def download_data():
 
     trainer = pl.Trainer(
         max_epochs=max_epochs,
-        gpus=int(use_gpu),
         precision=16 if use_gpu else 32,
         callbacks=[CUDACallback()] if use_gpu else [],
         plugins=plugin,

ray_lightning/examples/ray_ddp_tune.py

Lines changed: 0 additions & 9 deletions
@@ -30,7 +30,6 @@ def download_data():
 
     trainer = pl.Trainer(
         max_epochs=num_epochs,
-        gpus=int(use_gpu),
         callbacks=callbacks,
         progress_bar_refresh_rate=0,
         plugins=[
@@ -74,7 +73,6 @@ def tune_mnist(data_dir,
         num_samples=num_samples,
         resources_per_trial={
             "cpu": 1,
-            "gpu": int(use_gpu),
             "extra_cpu": num_workers,
             "extra_gpu": num_workers * int(use_gpu)
         },
@@ -111,11 +109,6 @@ def tune_mnist(data_dir,
         required=False,
         type=str,
         help="the address to use for Ray")
-    parser.add_argument(
-        "--server-address",
-        required=False,
-        type=str,
-        help="If using Ray Client, the address of the server to connect to. ")
     args, _ = parser.parse_known_args()
 
     num_epochs = 1 if args.smoke_test else args.num_epochs
@@ -125,8 +118,6 @@ def tune_mnist(data_dir,
 
     if args.smoke_test:
         ray.init(num_cpus=2)
-    elif args.server_address:
-        ray.util.connect(args.server_address)
     else:
         ray.init(address=args.address)
 

ray_lightning/examples/ray_horovod_example.py

Lines changed: 0 additions & 8 deletions
@@ -73,7 +73,6 @@ def train_mnist(config,
 
     trainer = pl.Trainer(
         max_epochs=num_epochs,
-        gpus=int(use_gpu),
         callbacks=callbacks,
         plugins=[
             HorovodRayPlugin(
@@ -174,11 +173,6 @@ def tune_mnist(data_dir,
         required=False,
         type=str,
         help="the address to use for Ray")
-    parser.add_argument(
-        "--server-address",
-        required=False,
-        type=str,
-        help="If using Ray Client, the address of the server to connect to. ")
     args, _ = parser.parse_known_args()
 
     num_epochs = 1 if args.smoke_test else args.num_epochs
@@ -189,8 +183,6 @@ def tune_mnist(data_dir,
 
     if args.smoke_test:
         ray.init(num_cpus=2)
-    elif args.server_address:
-        ray.util.connect(args.server_address)
     else:
         ray.init(address=args.address)
 

ray_lightning/ray_ddp.py

Lines changed: 44 additions & 5 deletions
@@ -1,17 +1,22 @@
+import io
 from typing import Callable, Dict, List, Union, Any
 
 import os
 from collections import defaultdict
 
-import ray
 import torch
+
+from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.plugins import DDPSpawnPlugin
 from pytorch_lightning import _logger as log, LightningModule
 from pytorch_lightning.utilities import rank_zero_only
+
+import ray
 from ray.util.sgd.utils import find_free_port
+from ray.util.queue import Queue
 
 from ray_lightning.session import init_session
-from ray_lightning.util import process_results, Queue
+from ray_lightning.util import process_results
 from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled
 from ray_lightning.ray_environment import RayEnvironment
 
@@ -161,6 +166,15 @@ def _setup_env_vars(self):
         values = [os.getenv(k) for k in keys]
         ray.get([w.set_env_vars.remote(keys, values) for w in self.workers])
 
+    def _load_state_stream(self, state_stream):
+        _buffer = io.BytesIO(state_stream)
+        to_gpu = self.use_gpu and torch.cuda.is_available()
+        state_dict = torch.load(
+            _buffer,
+            map_location=("cpu" if not to_gpu
+                          else lambda storage, loc: storage.cuda()))
+        return state_dict
+
     def execution_loop(self, trainer, tune_enabled: bool = True):
         """Main execution loop for training, testing, & prediction.
 
@@ -194,7 +208,8 @@ def execution_loop(self, trainer, tune_enabled: bool = True):
 
         results = process_results(futures, queue)
         # Get the results, checkpoint path, and model weights from worker 0.
-        results, best_path, state_dict = results[0]
+        results, best_path, state_stream = results[0]
+        state_dict = self._load_state_stream(state_stream)
         # Set the state for PTL using the output from remote training.
         self._results = results
         self._model = model
@@ -209,6 +224,24 @@ def execution_loop(self, trainer, tune_enabled: bool = True):
 
         return results
 
+    def setup_environment(self) -> None:
+        # Swap out the accelerator if necessary.
+        # This is needed to support CPU head with GPU workers or Ray Client.
+        current_accelerator = self.lightning_module.trainer.accelerator
+        if self.use_gpu and isinstance(current_accelerator, CPUAccelerator):
+            from weakref import proxy
+            from ray_lightning.util import DelayedGPUAccelerator
+            precision_plugin = current_accelerator.precision_plugin
+            new_accelerator = DelayedGPUAccelerator(
+                precision_plugin=precision_plugin, training_type_plugin=self)
+            self.lightning_module.trainer.accelerator_connector\
+                ._training_type_plugin = \
+                proxy(new_accelerator.training_type_plugin)
+            self.lightning_module.trainer.accelerator_connector\
+                ._precision_plugin = proxy(new_accelerator.precision_plugin)
+            self.lightning_module.trainer.accelerator_connector.accelerator \
+                = new_accelerator
+
     def start_training(self, trainer):
         results = self.execution_loop(trainer, tune_enabled=True)
         # reset optimizers, since main process is never used for training and
@@ -268,7 +301,7 @@ def execute_remote(self,
             mp_queue=None)
         # Only need results from worker 0.
         if self.global_rank == 0:
-            return self.results, self.best_model_path, self.model_state_dict
+            return self.results, self.best_model_path, self.model_state_stream
         else:
             return None
 
@@ -307,12 +340,18 @@ def root_device(self):
         else:
             return torch.device("cpu")
 
+    def _to_state_stream(self, model_state_dict):
+        _buffer = io.BytesIO()
+        torch.save(model_state_dict, _buffer)
+        return _buffer.getvalue()
+
     def transfer_distrib_spawn_state_on_fit_end(self, results):
         """Sets the training output as attributes so it can be retrieved."""
         if self.global_rank == 0:
             # Save training results as attributes.
             self._results = results
-            self.model_state_dict = self.lightning_module.state_dict()
+            self.model_state_stream = \
+                self._to_state_stream(self.lightning_module.state_dict())
             best_model_path = None
             if self.lightning_module.trainer.checkpoint_callback is not None:
                 best_model_path = \
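The new `_to_state_stream` / `_load_state_stream` pair means worker 0 now ships the trained weights back to the driver as a serialized byte stream rather than a live `state_dict`, so the driver can materialize the tensors on CPU even when the workers trained on GPU. A minimal standalone sketch of that round trip (the `torch.nn.Linear` model and the free-standing function names are illustrative only, not part of the plugin):

```python
import io

import torch


def to_state_stream(model_state_dict):
    # Same idea as RayPlugin._to_state_stream: dump the weights to raw bytes.
    buffer = io.BytesIO()
    torch.save(model_state_dict, buffer)
    return buffer.getvalue()


def load_state_stream(state_stream, to_gpu=False):
    # Same idea as RayPlugin._load_state_stream: rebuild the state dict on
    # the receiving side, mapping to CPU unless a GPU is wanted and available.
    buffer = io.BytesIO(state_stream)
    use_cuda = to_gpu and torch.cuda.is_available()
    map_location = ((lambda storage, loc: storage.cuda())
                    if use_cuda else "cpu")
    return torch.load(buffer, map_location=map_location)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)  # toy model for illustration
    stream = to_state_stream(model.state_dict())      # what the worker returns
    model.load_state_dict(load_state_stream(stream))  # what the driver does
```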

ray_lightning/ray_horovod.py

Lines changed: 4 additions & 2 deletions
@@ -1,12 +1,14 @@
-import ray
 import torch
 from pytorch_lightning import LightningModule
 from pytorch_lightning.plugins import HorovodPlugin
 from pytorch_lightning.utilities import rank_zero_only
+
+import ray
 from ray import ObjectRef
+from ray.util.queue import Queue
 
 from ray_lightning.session import init_session
-from ray_lightning.util import process_results, Queue, Unavailable
+from ray_lightning.util import process_results, Unavailable
 from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled
 
 try:
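Both plugins now import `Queue` directly from `ray.util.queue` instead of the backported copy that `ray_lightning/util.py` carried for Ray < 1.2 (removed in the `util.py` diff below). A quick sketch of the built-in queue passing a value from a Ray task back to the driver, assuming a Ray version where `ray.util.queue.Queue` is available:

```python
import ray
from ray.util.queue import Queue  # built into Ray since the 1.2 release


@ray.remote
def worker(queue):
    # e.g. a training worker reporting an intermediate result to the driver
    queue.put("result-from-worker")


if __name__ == "__main__":
    ray.init(num_cpus=2)
    queue = Queue(maxsize=8)
    ray.get(worker.remote(queue))
    print(queue.get())  # -> "result-from-worker"
    ray.shutdown()
```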

ray_lightning/util.py

Lines changed: 17 additions & 75 deletions
@@ -1,88 +1,30 @@
-# Remove after Ray 1.2 release.
-import asyncio
-from typing import Optional, Dict, Callable
+from typing import Callable
 
-import ray
-from ray.util.queue import Queue as RayQueue, Empty, Full
-
-
-class Unavailable:
-    """No object should be instance of this class"""
-
-    def __init__(self, *args, **kwargs):
-        raise RuntimeError("This class should never be instantiated.")
-
-
-# Remove after Ray 1.2 release.
-if getattr(RayQueue, "shutdown", None) is not None:
-    from ray.util.queue import _QueueActor
-else:
-    # On Ray <v1.2, we have to create our own class so we can create it with
-    # custom resources.
-    class _QueueActor:
-        """A class with basic Queue functionality."""
+from pytorch_lightning.accelerators import GPUAccelerator
+from pytorch_lightning import Trainer, LightningModule
 
-        def __init__(self, maxsize):
-            self.maxsize = maxsize
-            self.queue = asyncio.Queue(self.maxsize)
-
-        def qsize(self):
-            return self.queue.qsize()
-
-        def empty(self):
-            return self.queue.empty()
-
-        def full(self):
-            return self.queue.full()
-
-        async def put(self, item, timeout=None):
-            try:
-                await asyncio.wait_for(self.queue.put(item), timeout)
-            except asyncio.TimeoutError:
-                raise Full
+import ray
 
-        async def get(self, timeout=None):
-            try:
-                return await asyncio.wait_for(self.queue.get(), timeout)
-            except asyncio.TimeoutError:
-                raise Empty
 
-        def put_nowait(self, item):
-            self.queue.put_nowait(item)
+class DelayedGPUAccelerator(GPUAccelerator):
+    """Same as GPUAccelerator, but doesn't do any CUDA setup.
 
-        def put_nowait_batch(self, items):
-            # If maxsize is 0, queue is unbounded, so no need to check size.
-            if self.maxsize > 0 and len(items) + self.qsize() > self.maxsize:
-                raise Full(f"Cannot add {len(items)} items to queue of size "
-                           f"{self.qsize()} and maxsize {self.maxsize}.")
-            for item in items:
-                self.queue.put_nowait(item)
+    This allows the driver script to be launched from CPU-only machines (
+    like the laptop) but have training still execute on GPU.
+    """
 
-        def get_nowait(self):
-            return self.queue.get_nowait()
+    def setup(self, trainer: Trainer, model: LightningModule) -> None:
+        return super(GPUAccelerator, self).setup(trainer, model)
 
-        def get_nowait_batch(self, num_items):
-            if num_items > self.qsize():
-                raise Empty(f"Cannot get {num_items} items from queue of size "
-                            f"{self.qsize()}.")
-            return [self.queue.get_nowait() for _ in range(num_items)]
+    def on_train_start(self) -> None:
+        super(DelayedGPUAccelerator, self).on_train_start()
 
 
-class Queue(RayQueue):
-    def __init__(self, maxsize: int = 0,
-                 actor_options: Optional[Dict] = None) -> None:
-        actor_options = actor_options or {}
-        self.maxsize = maxsize
-        self.actor = ray.remote(_QueueActor).options(**actor_options).remote(
-            self.maxsize)
+class Unavailable:
+    """No object should be instance of this class"""
 
-    def shutdown(self):
-        if getattr(RayQueue, "shutdown", None) is not None:
-            super(Queue, self).shutdown()
-        else:
-            if self.actor:
-                ray.kill(self.actor)
-                self.actor = None
+    def __init__(self, *args, **kwargs):
+        raise RuntimeError("This class should never be instantiated.")
 
 
 def _handle_queue(queue):
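The trick in `DelayedGPUAccelerator.setup` is the `super(GPUAccelerator, self)` call: it starts the method lookup above `GPUAccelerator` in the MRO, so the CUDA device setup that `GPUAccelerator.setup` normally performs is skipped on the driver, while `on_train_start` still defers to the regular `GPUAccelerator` behavior once training actually begins on the GPU workers. A toy illustration of that "skip the parent, call the grandparent" pattern, using stand-in classes rather than the real PyTorch Lightning ones:

```python
class Accelerator:
    def setup(self):
        print("generic setup, no CUDA calls")


class GPUAccelerator(Accelerator):
    def setup(self):
        print("would touch CUDA here, e.g. torch.cuda.set_device(...)")
        super().setup()


class DelayedGPUAccelerator(GPUAccelerator):
    def setup(self):
        # Skip GPUAccelerator.setup by starting the lookup one level higher,
        # mirroring super(GPUAccelerator, self).setup(trainer, model) above.
        return super(GPUAccelerator, self).setup()


if __name__ == "__main__":
    DelayedGPUAccelerator().setup()  # prints only "generic setup, no CUDA calls"
```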
