This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 0a47405

PTL 1.2.10 Compatibility (#41)
* pass in ddp_kwargs
* compat
* format
* hvd ranks default
1 parent f735548 commit 0a47405

File tree

5 files changed: +77 -15 lines changed

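For orientation, here is a minimal, hedged sketch of how the plugin is driven under PTL 1.2.10, modeled on the updated test further down; the model class and trainer settings are placeholders, not part of this commit:

import pytorch_lightning as pl
import ray
from ray_lightning import RayPlugin  # assumes the package's top-level export

ray.init()

# num_workers/num_cpus_per_worker/use_gpu mirror the constructor arguments
# touched in this commit; any extra **ddp_kwargs are now forwarded through
# to DDPSpawnPlugin.__init__.
plugin = RayPlugin(num_workers=2, num_cpus_per_worker=1, use_gpu=False)

model = MyLightningModule()  # hypothetical pl.LightningModule subclass
trainer = pl.Trainer(max_epochs=1, plugins=[plugin])
trainer.fit(model)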

ray_lightning/ray_ddp.py

Lines changed: 15 additions & 6 deletions
@@ -7,11 +7,13 @@
 import torch
 from pytorch_lightning.plugins import DDPSpawnPlugin
 from pytorch_lightning import _logger as log, LightningModule
+from pytorch_lightning.utilities import rank_zero_only
 from ray.util.sgd.utils import find_free_port
 
 from ray_lightning.session import init_session
 from ray_lightning.util import process_results, Queue
 from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled
+from ray_lightning.ray_environment import RayEnvironment
 
 
 @ray.remote
@@ -95,13 +97,17 @@ def __init__(self,
         if not ray.is_initialized():
             ray.init()
         super().__init__(
-            sync_batchnorm=None, parallel_devices=[], **ddp_kwargs)
+            sync_batchnorm=None,
+            parallel_devices=[],
+            cluster_environment=RayEnvironment(world_size=num_workers),
+            **ddp_kwargs)
         self.nickname = "ddp_ray"
         self.num_workers = num_workers
         self.num_cpus_per_worker = num_cpus_per_worker
         self.use_gpu = use_gpu
         self.workers = []
         self.init_hook = init_hook
+        self._local_rank = 0
 
     def _create_worker(self):
         """Creates Ray actor."""
@@ -225,7 +231,7 @@ def train_remote(self,
         self.lightning_module.trainer.accelerator_connector\
             ._training_type_plugin = self
         self.lightning_module.trainer.accelerator.training_type_plugin = self
-        self.global_rank = global_rank
+        self.cluster_environment.set_global_rank(global_rank)
 
         if queue is not None:
             # Initialize session.
@@ -263,11 +269,14 @@ def init_ddp_connection(self,
             world_size=world_size,
         )
 
-    def set_world_ranks(self, process_idx: int):
+    def set_world_ranks(self, process_idx: int = 0):
         """Set the appropriate rank attribues for the trainer."""
-        self.local_rank = self.global_to_local[self.global_rank]
-        self.global_rank = self.global_rank
-        self.world_size = self.num_workers
+        assert self.cluster_environment is not None
+        if self.global_rank is not None:
+            self._local_rank = self.global_to_local[self.global_rank]
+        self.cluster_environment.set_global_rank(self.global_rank)
+        self.cluster_environment.set_world_size(self.num_workers)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
 
     @property
     def root_device(self):
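The substantive change above is that rank bookkeeping now flows through the cluster environment instead of plain attributes on the plugin, with set_world_ranks consulting global_to_local to recover the worker-local index. A small self-contained sketch of that mapping, with an invented two-node layout (not code from this repository):

# Hypothetical layout: 4 global ranks spread across 2 nodes, 2 workers each.
# global_to_local[rank] is the worker's index on its own node, which is what
# set_world_ranks stores in self._local_rank.
node_of_rank = {0: "A", 1: "A", 2: "B", 3: "B"}

seen_per_node = {}
global_to_local = []
for rank in sorted(node_of_rank):
    node = node_of_rank[rank]
    global_to_local.append(seen_per_node.get(node, 0))
    seen_per_node[node] = seen_per_node.get(node, 0) + 1

assert global_to_local == [0, 1, 0, 1]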

ray_lightning/ray_environment.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+from pytorch_lightning.plugins.environments import ClusterEnvironment
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class RayEnvironment(ClusterEnvironment):
+    """Environment for PTL training on a Ray cluster."""
+
+    def __init__(self, world_size):
+        self.set_world_size(world_size)
+        self._global_rank = None
+
+    def creates_children(self) -> bool:
+        return False
+
+    def master_address(self) -> str:
+        raise NotImplementedError
+
+    def master_port(self) -> int:
+        raise NotImplementedError
+
+    def world_size(self) -> int:
+        return self._world_size
+
+    def set_world_size(self, size: int) -> None:
+        self._world_size = size
+
+    def global_rank(self) -> int:
+        return self._global_rank
+
+    def set_global_rank(self, rank: int) -> None:
+        self._global_rank = rank
+        rank_zero_only.rank = rank
+
+    def local_rank(self) -> int:
+        raise NotImplementedError
+
+    def node_rank(self) -> int:
+        raise NotImplementedError
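The new environment is small enough that its contract fits in a few asserts. A minimal sketch (ranks and world size invented for illustration):

from pytorch_lightning.utilities import rank_zero_only
from ray_lightning.ray_environment import RayEnvironment

env = RayEnvironment(world_size=4)
assert env.world_size() == 4

# set_global_rank also updates the rank consulted by @rank_zero_only, so
# logging and checkpointing decorators behave correctly on remote workers.
env.set_global_rank(2)
assert env.global_rank() == 2
assert rank_zero_only.rank == 2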

ray_lightning/ray_horovod.py

Lines changed: 18 additions & 3 deletions
@@ -92,6 +92,24 @@ def __setstate__(self, d):
         d["executor"] = None
         self.__dict__.update(d)
 
+    @property
+    def global_rank(self) -> int:
+        if not hvd.is_initialized():
+            return 0
+        return hvd.rank()
+
+    @property
+    def local_rank(self) -> int:
+        if not hvd.is_initialized():
+            return 0
+        return hvd.local_rank()
+
+    @property
+    def world_size(self) -> int:
+        if not hvd.is_initialized():
+            return self.num_hosts * self.num_slots
+        return hvd.size()
+
     def setup(self, model: LightningModule):
         """Creates the RayExecutor object."""
         self._model = model
@@ -152,9 +170,6 @@ def train_remote(self, model: ObjectRef, queue: Queue = None, **kwargs):
         self.lightning_module.trainer.accelerator.training_type_plugin = self
 
         hvd.init()
-        self.global_rank = hvd.rank()
-        self.local_rank = hvd.local_rank()
-        self.world_size = hvd.size()
         rank_zero_only.rank = self.global_rank
 
         if queue is not None:
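Per the "hvd ranks default" item in the commit message, the rank attributes become lazy properties: they delegate to Horovod once hvd.init() has run and fall back to defaults before that, so the driver process can query them safely. The pattern in isolation, as a sketch rather than the plugin itself:

import horovod.torch as hvd

def current_world_size(num_hosts: int, num_slots: int) -> int:
    # Before hvd.init(), report the configured size (driver side);
    # afterwards, defer to Horovod's own view of the job.
    if not hvd.is_initialized():
        return num_hosts * num_slots
    return hvd.size()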

ray_lightning/tests/test_ddp.py

Lines changed: 2 additions & 2 deletions
@@ -31,14 +31,14 @@ def test_actor_creation(tmpdir, ray_start_2_cpus, num_workers):
     model = BoringModel()
 
     def check_num_actor():
-        assert len(ray.actors()) == num_workers
+        assert len(ray.state.actors()) == num_workers
 
     model.on_epoch_end = check_num_actor
     plugin = RayPlugin(num_workers=num_workers)
     trainer = get_trainer(tmpdir, plugins=[plugin])
     trainer.fit(model)
     assert all(actor["State"] == ray.gcs_utils.ActorTableData.DEAD
-               for actor in list(ray.actors().values()))
+               for actor in list(ray.state.actors().values()))
 
 
 def test_distributed_sampler(tmpdir, ray_start_2_cpus):
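The test change tracks Ray relocating the actor-table accessor from ray.actors() to ray.state.actors(). If code must span both Ray versions, a hedged shim like the following would work; the commit itself simply calls ray.state.actors() directly:

import ray

def list_actors():
    # Prefer the newer ray.state.actors(); fall back for older releases.
    state = getattr(ray, "state", None)
    if state is not None and hasattr(state, "actors"):
        return ray.state.actors()
    return ray.actors()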

requirements-test.txt

Lines changed: 4 additions & 4 deletions
@@ -1,9 +1,9 @@
-flake8==3.7.7
+flake8==3.9.1
 flake8-comprehensions
-flake8-quotes==2.0.0
+flake8-quotes
 yapf==0.23.0
 pytest
-pytorch-lightning==1.2.6
-lightning-bolts==0.3.2
+pytorch-lightning==1.2.10
+lightning-bolts==0.3.3
 ray[tune]
 torchvision
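The pins move pytorch-lightning to 1.2.10 and lightning-bolts to 0.3.3, and relax the flake8-quotes pin. A small sanity check that an environment matches the new pins (a sketch; assumes lightning-bolts' pl_bolts import name):

import pytorch_lightning as pl
import pl_bolts

assert pl.__version__ == "1.2.10", pl.__version__
assert pl_bolts.__version__ == "0.3.3", pl_bolts.__version__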
