This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 972d3dd

Add API Annotations (#88)

* add annotations
* update test

1 parent bc59fd4

File tree

5 files changed: 13 additions & 1 deletion
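The change is purely additive: each file imports Ray's PublicAPI decorator and applies it to the user-facing entry points of ray_lightning. A minimal sketch of the pattern, using only what the diff shows (SomePlugin is a hypothetical stand-in for the real classes):

from ray.util import PublicAPI


@PublicAPI(stability="beta")
class SomePlugin:
    # Hypothetical stand-in. The decorator marks the class as part of the
    # public API at "beta" stability, i.e. usable but still subject to change.
    pass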

ray_lightning/ray_ddp.py

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@
 from pytorch_lightning.utilities import rank_zero_only
 
 import ray
+from ray.util import PublicAPI
 from ray.util.queue import Queue
 
 from ray_lightning.session import init_session
@@ -54,6 +55,7 @@ def execute(self, fn: Callable, *args, **kwargs):
         return fn(*args, **kwargs)
 
 
+@PublicAPI(stability="beta")
 class RayPlugin(DDPSpawnPlugin):
     """Pytorch Lightning plugin for DDP training on a Ray cluster.
 
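The annotation does not change how the plugin is used. A rough usage sketch, not taken from this commit: the model, data, and max_epochs value are illustrative, while RayPlugin(num_workers=...) and plugins=[plugin] follow the updated test below.

import ray
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from ray_lightning import RayPlugin


class BoringModel(pl.LightningModule):
    # Tiny stand-in model so the sketch is self-contained; not part of the repo.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


ray.init()  # connect to (or start) a Ray cluster
plugin = RayPlugin(num_workers=2)  # same constructor call as in the updated test
trainer = pl.Trainer(max_epochs=1, plugins=[plugin])
data = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
trainer.fit(BoringModel(), data)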
ray_lightning/ray_ddp_sharded.py

Lines changed: 3 additions & 0 deletions

@@ -1,7 +1,10 @@
 from pytorch_lightning.plugins import DDPSpawnShardedPlugin
 
+from ray.util import PublicAPI
+
 from ray_lightning import RayPlugin
 
 
+@PublicAPI(stability="beta")
 class RayShardedPlugin(RayPlugin, DDPSpawnShardedPlugin):
     pass
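The sharded variant gets the same annotation. A short sketch, assuming its constructor matches RayPlugin (the class adds nothing of its own beyond combining the two parent plugins):

from ray_lightning.ray_ddp_sharded import RayShardedPlugin

# Drop-in replacement for RayPlugin when sharded DDP is wanted.
plugin = RayShardedPlugin(num_workers=2)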

ray_lightning/ray_horovod.py

Lines changed: 2 additions & 0 deletions

@@ -5,6 +5,7 @@
 
 import ray
 from ray import ObjectRef
+from ray.util import PublicAPI
 from ray.util.queue import Queue
 
 from ray_lightning.session import init_session
@@ -28,6 +29,7 @@ def get_executable_cls():
     return None
 
 
+@PublicAPI(stability="beta")
 class HorovodRayPlugin(HorovodPlugin):
     """Pytorch Lightning Plugin for Horovod training on a Ray cluster.
 
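HorovodRayPlugin is handed to the Trainer the same way. Its constructor arguments are not visible in this diff, so the sketch below relies on defaults (an assumption) and requires Horovod to be installed.

import pytorch_lightning as pl
from ray_lightning.ray_horovod import HorovodRayPlugin

# Assumes the constructor has usable defaults; pass it like any other plugin.
trainer = pl.Trainer(max_epochs=1, plugins=[HorovodRayPlugin()])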
ray_lightning/tests/test_ddp.py

Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 import pytest
 from ray.util.client.ray_client_helpers import ray_start_client_server
+import ray._private.gcs_utils as gcs_utils
 from torch.utils.data import DistributedSampler
 
 from pl_bolts.datamodules import MNISTDataModule
@@ -45,7 +46,7 @@ def check_num_actor():
     plugin = RayPlugin(num_workers=num_workers)
     trainer = get_trainer(tmpdir, plugins=[plugin])
     trainer.fit(model)
-    assert all(actor["State"] == ray.gcs_utils.ActorTableData.DEAD
+    assert all(actor["State"] == gcs_utils.ActorTableData.DEAD
                for actor in list(ray.state.actors().values()))
 
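The only substantive point of the test update is where gcs_utils comes from: it is now imported from ray._private.gcs_utils rather than reached through ray.gcs_utils. The assertion itself is unchanged; pulled out as a hypothetical standalone helper it looks roughly like this:

import ray
import ray._private.gcs_utils as gcs_utils


def all_actors_dead() -> bool:
    # True once every actor recorded in Ray's GCS has reached the DEAD state,
    # which is what the test expects after trainer.fit() returns.
    return all(actor["State"] == gcs_utils.ActorTableData.DEAD
               for actor in ray.state.actors().values())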
ray_lightning/tune.py

Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,7 @@
 import os
 
 from pytorch_lightning import Trainer, LightningModule
+from ray.util import PublicAPI
 
 from ray_lightning.session import put_queue, get_actor_rank
 from ray_lightning.util import to_state_stream, Unavailable
@@ -26,6 +27,7 @@ def is_session_enabled():
 
 if TUNE_INSTALLED:
 
+    @PublicAPI(stability="beta")
     def get_tune_ddp_resources(num_workers: int = 1,
                                cpus_per_worker: int = 1,
                                use_gpu: bool = False) -> Dict[str, int]:
@@ -40,6 +42,7 @@ def get_tune_ddp_resources(num_workers: int = 1,
             bundles, strategy="PACK")
         return placement_group_factory
 
+    @PublicAPI(stability="beta")
     class TuneReportCallback(TuneCallback):
         """Distributed PyTorch Lightning to Ray Tune reporting callback
 
@@ -161,6 +164,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule):
             put_queue(lambda: self._create_checkpoint(
                 checkpoint_stream, global_step, self._filename))
 
+    @PublicAPI(stability="beta")
     class TuneReportCheckpointCallback(TuneCallback):
         """PyTorch Lightning to Tune reporting and checkpointing callback.
 
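How the annotated Tune pieces fit together, as a hedged sketch: the callback's constructor arguments and the use of the returned placement-group factory as resources_per_trial are assumptions not shown in this diff, and BoringModel with its DataLoader are the stand-ins from the RayPlugin sketch above.

import pytorch_lightning as pl
from ray import tune
from ray_lightning import RayPlugin
from ray_lightning.tune import TuneReportCallback, get_tune_ddp_resources


def train_fn(config):
    # BoringModel and data are the stand-ins defined in the RayPlugin sketch above.
    trainer = pl.Trainer(
        max_epochs=1,
        callbacks=[TuneReportCallback(["train_loss"])],  # assumed signature
        plugins=[RayPlugin(num_workers=2)])
    trainer.fit(BoringModel(), data)


tune.run(
    train_fn,
    num_samples=2,
    # Reserve resources for each trial's two training workers (assumed usage of
    # the placement-group factory returned by get_tune_ddp_resources).
    resources_per_trial=get_tune_ddp_resources(num_workers=2, cpus_per_worker=1))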