This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit f2ffd82

Fix Tune GPU Checkpointing (#70)
* fix
* add gpu test
* update docs
1 parent c3b13a7 commit f2ffd82

File tree

4 files changed: +53 −23 lines

* ray_lightning/ray_ddp.py
* ray_lightning/tests/test_tune.py
* ray_lightning/tune.py
* ray_lightning/util.py

ray_lightning/ray_ddp.py

Lines changed: 4 additions & 18 deletions
@@ -1,4 +1,3 @@
-import io
 import socket
 from contextlib import closing
 from typing import Callable, Dict, List, Union, Any
@@ -17,7 +16,8 @@
 from ray.util.queue import Queue
 
 from ray_lightning.session import init_session
-from ray_lightning.util import process_results
+from ray_lightning.util import process_results, to_state_stream, \
+    load_state_stream
 from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled
 from ray_lightning.ray_environment import RayEnvironment
 
@@ -174,15 +174,6 @@ def _setup_env_vars(self):
         values = [os.getenv(k) for k in keys]
         ray.get([w.set_env_vars.remote(keys, values) for w in self.workers])
 
-    def _load_state_stream(self, state_stream):
-        _buffer = io.BytesIO(state_stream)
-        to_gpu = self.use_gpu and torch.cuda.is_available()
-        state_dict = torch.load(
-            _buffer,
-            map_location=("cpu" if not to_gpu
-                          else lambda storage, loc: storage.cuda()))
-        return state_dict
-
     def execution_loop(self, trainer, tune_enabled: bool = True):
         """Main execution loop for training, testing, & prediction.
 
@@ -217,7 +208,7 @@ def execution_loop(self, trainer, tune_enabled: bool = True):
         results = process_results(futures, queue)
         # Get the results, checkpoint path, and model weights from worker 0.
         results, best_path, state_stream = results[0]
-        state_dict = self._load_state_stream(state_stream)
+        state_dict = load_state_stream(state_stream, to_gpu=self.use_gpu)
         # Set the state for PTL using the output from remote training.
         self._results = results
         self._model = model
@@ -348,18 +339,13 @@ def root_device(self):
         else:
             return torch.device("cpu")
 
-    def _to_state_stream(self, model_state_dict):
-        _buffer = io.BytesIO()
-        torch.save(model_state_dict, _buffer)
-        return _buffer.getvalue()
-
     def transfer_distrib_spawn_state_on_fit_end(self, results):
         """Sets the training output as attributes so it can be retrieved."""
         if self.global_rank == 0:
             # Save training results as attributes.
             self._results = results
             self.model_state_stream = \
-                self._to_state_stream(self.lightning_module.state_dict())
+                to_state_stream(self.lightning_module.state_dict())
             best_model_path = None
             if self.lightning_module.trainer.checkpoint_callback is not None:
                 best_model_path = \
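For orientation, here is a minimal sketch of how these plugin changes are exercised from user code. The ToyModel module below is purely illustrative and not part of the commit; RayPlugin(num_workers=..., use_gpu=...) and passing the plugin through Trainer(plugins=[...]) follow the project's documented usage, and use_gpu=True assumes GPUs are visible to Ray.

import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from ray_lightning import RayPlugin


class ToyModel(pl.LightningModule):
    # Minimal LightningModule used only to drive the plugin.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# After trainer.fit() returns, worker 0's weights come back as a byte
# stream and are restored on the driver via
# load_state_stream(state_stream, to_gpu=use_gpu).
plugin = RayPlugin(num_workers=2, use_gpu=True)
trainer = pl.Trainer(max_epochs=1, plugins=[plugin])
trainer.fit(ToyModel())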

ray_lightning/tests/test_tune.py

Lines changed: 17 additions & 0 deletions
@@ -2,6 +2,7 @@
 import pytest
 
 import ray
+import torch
 from ray import tune
 
 from ray_lightning import RayPlugin, HorovodRayPlugin
@@ -84,3 +85,19 @@ def test_checkpoint_horovod(tmpdir, ray_start_4_cpus):
     """Tests if Tune checkpointing works with HorovodRayAccelerator."""
     plugin = HorovodRayPlugin(num_hosts=1, num_slots=2, use_gpu=False)
     checkpoint_test(tmpdir, plugin)
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_checkpoint_ddp_gpu(tmpdir, ray_start_4_cpus):
+    """Tests if Tune checkpointing works with RayAccelerator."""
+    plugin = RayPlugin(num_workers=2, use_gpu=False)
+    checkpoint_test(tmpdir, plugin)
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_checkpoint_horovod_gpu(tmpdir, ray_start_4_cpus):
+    """Tests if Tune checkpointing works with HorovodRayAccelerator."""
+    plugin = HorovodRayPlugin(num_hosts=1, num_slots=2, use_gpu=False)
+    checkpoint_test(tmpdir, plugin)

ray_lightning/tune.py

Lines changed: 8 additions & 5 deletions
@@ -1,12 +1,12 @@
 from typing import Dict, List, Union
 
+import fsspec
 import os
 
-from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning import Trainer, LightningModule
 
 from ray_lightning.session import put_queue, get_actor_rank
-from ray_lightning.util import Unavailable
+from ray_lightning.util import to_state_stream, Unavailable
 
 try:
     from ray import tune
@@ -143,20 +143,23 @@ def __init__(self,
         self._filename = filename
 
     @staticmethod
-    def _create_checkpoint(checkpoint_dict: dict, global_step: int,
+    def _create_checkpoint(checkpoint_stream, global_step: int,
                            filename: str):
         with tune.checkpoint_dir(step=global_step) as checkpoint_dir:
             file_path = os.path.join(checkpoint_dir, filename)
-            atomic_save(checkpoint_dict, file_path)
+            with fsspec.open(file_path, "wb") as f:
+                f.write(checkpoint_stream)
 
     def _handle(self, trainer: Trainer, pl_module: LightningModule):
         if trainer.running_sanity_check:
             return
         checkpoint_dict = trainer.checkpoint_connector.dump_checkpoint()
+        # Convert to a state stream first.
+        checkpoint_stream = to_state_stream(checkpoint_dict)
         global_step = trainer.global_step
         if get_actor_rank() == 0:
             put_queue(lambda: self._create_checkpoint(
-                checkpoint_dict, global_step, self._filename))
+                checkpoint_stream, global_step, self._filename))
 
 class TuneReportCheckpointCallback(TuneCallback):
     """PyTorch Lightning to Tune reporting and checkpointing callback.

ray_lightning/util.py

Lines changed: 24 additions & 0 deletions
@@ -1,5 +1,7 @@
+import io
 from typing import Callable
 
+import torch
 from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning import Trainer, LightningModule
 
@@ -51,3 +53,25 @@ def process_results(training_result_futures, queue):
     # Process any remaining items in queue.
     _handle_queue(queue)
     return ray.get(training_result_futures)
+
+
+def to_state_stream(model_state_dict):
+    """Converts the given state dict to a stream of bytes."""
+    _buffer = io.BytesIO()
+    torch.save(model_state_dict, _buffer)
+    return _buffer.getvalue()
+
+
+def load_state_stream(state_stream, to_gpu):
+    """Converts the state stream to a state dict on the appropriate device.
+
+    Converts to GPU if ``to_gpu`` is True and CUDA is available.
+
+    """
+    _buffer = io.BytesIO(state_stream)
+    to_gpu = to_gpu and torch.cuda.is_available()
+    state_dict = torch.load(
+        _buffer,
+        map_location=("cpu"
+                      if not to_gpu else lambda storage, loc: storage.cuda()))
+    return state_dict
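A quick sanity check of the new helpers (a sketch, not part of the commit): requesting GPU placement is safe even on a CPU-only machine, because load_state_stream ANDs the to_gpu flag with torch.cuda.is_available() before picking the map_location.

import torch

from ray_lightning.util import load_state_stream, to_state_stream

stream = to_state_stream({"weight": torch.ones(3)})

# Falls back to CPU tensors when no CUDA device is present.
restored = load_state_stream(stream, to_gpu=True)
expected = "cuda" if torch.cuda.is_available() else "cpu"
assert restored["weight"].device.type == expected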
