This repository was archived by the owner on Nov 3, 2023. It is now read-only.
README.md: 4 additions & 0 deletions
@@ -1,6 +1,10 @@
 <!--$UNCOMMENT(ray-lightning)=-->
 
 # Distributed PyTorch Lightning Training on Ray
+
+## Updated for pytorch and pytorch-lightning 2
+
+
 This library adds new PyTorch Lightning strategies for distributed training using the Ray distributed computing framework.
 
 These PyTorch Lightning strategies on Ray enable quick and easy parallel training while still leveraging all the benefits of PyTorch Lightning and using your desired training protocol, either [PyTorch Distributed Data Parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) or [Horovod](https://github.com/horovod/horovod).
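For context, the strategies shipped here are consumed through the standard Lightning `Trainer` API. The sketch below is illustrative only (it is not part of this diff) and assumes the constructor keeps the `num_workers`/`use_gpu` arguments from earlier ray_lightning releases; `MyLightningModule` and `train_loader` are hypothetical placeholders.

```python
# Illustrative usage sketch -- not part of this pull request.
import pytorch_lightning as pl
import ray

from ray_lightning import RayStrategy

ray.init()  # or ray.init(address="auto") to attach to an existing cluster

model = MyLightningModule()  # hypothetical LightningModule
trainer = pl.Trainer(
    strategy=RayStrategy(num_workers=4, use_gpu=False),  # assumed constructor args
    max_epochs=1,
)
trainer.fit(model, train_loader)  # hypothetical DataLoader
```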
ray_lightning/__init__.py: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 from ray_lightning.ray_ddp import RayStrategy
-from ray_lightning.ray_horovod import HorovodRayStrategy
-from ray_lightning.ray_ddp_sharded import RayShardedStrategy
+# from ray_lightning.ray_horovod import HorovodRayStrategy
+# from ray_lightning.ray_ddp_sharded import RayShardedStrategy
 
-__all__ = ["RayStrategy", "HorovodRayStrategy", "RayShardedStrategy"]
+__all__ = ["RayStrategy"]#, "HorovodRayStrategy", "RayShardedStrategy"]
ray_lightning/accelerators/__init__.py: 3 additions & 3 deletions
@@ -10,12 +10,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pytorch_lightning.accelerators.registry import \
-    call_register_accelerators  # noqa: F401
+from pytorch_lightning.accelerators import AcceleratorRegistry
+from lightning_fabric.accelerators.registry import call_register_accelerators
 from ray_lightning.accelerators.delayed_gpu_accelerator import _GPUAccelerator
 
 # these lines are to register the delayed gpu accelerator as `_gpu`
 ACCELERATORS_BASE_MODULE = "ray_lightning.accelerators"
-call_register_accelerators(ACCELERATORS_BASE_MODULE)
+call_register_accelerators(AcceleratorRegistry, ACCELERATORS_BASE_MODULE)
 
 __all__ = ["_GPUAccelerator"]
ray_lightning/accelerators/delayed_gpu_accelerator.py: 2 additions & 2 deletions
@@ -16,10 +16,10 @@
 import torch
 
 from pytorch_lightning.accelerators import Accelerator,\
-    GPUAccelerator
+    CUDAAccelerator
 
 
-class _GPUAccelerator(GPUAccelerator):
+class _GPUAccelerator(CUDAAccelerator):
     """Accelerator for GPU devices.
 
     adapted from:
ray_lightning/launchers/__init__.py: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 from ray_lightning.launchers.ray_launcher import RayLauncher
-from ray_lightning.launchers.ray_horovod_launcher import RayHorovodLauncher
+# from ray_lightning.launchers.ray_horovod_launcher import RayHorovodLauncher
 
-__all__ = ["RayLauncher", "RayHorovodLauncher"]
+__all__ = ["RayLauncher"]#, "RayHorovodLauncher"]
ray_lightning/launchers/ray_horovod_launcher.py: 1 addition & 2 deletions
@@ -2,8 +2,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.strategies.launchers import _Launcher
-from pytorch_lightning.utilities.apply_func import apply_to_collection, \
-    move_data_to_device
+from lightning_utilities.core.apply_func import apply_to_collection
 import numpy as np
 import torch
 
ray_lightning/launchers/ray_launcher.py: 23 additions & 9 deletions
@@ -4,9 +4,8 @@
 import os
 
 import pytorch_lightning as pl
-from pytorch_lightning.strategies.launchers import _Launcher
-from pytorch_lightning.utilities.apply_func import apply_to_collection,\
-    move_data_to_device
+from pytorch_lightning.strategies.launchers.launcher import _Launcher
+from lightning_utilities.core.apply_func import apply_to_collection
 import numpy as np
 import torch
 
@@ -24,7 +23,19 @@
     RayExecutor
 
 
-class RayLauncher(_Launcher):
+# def to_cpu(d):
+#     if hasattr(d, "cpu"):
+#         return d.cpu()
+#     elif isinstance(d, list):
+#         return [to_cpu(x) for x in d]
+#     elif isinstance(d, tuple):
+#         return tuple([to_cpu(x) for x in d])
+#     elif isinstance(d, dict):
+#         return {k: to_cpu(v) for k,v in d.items()}
+#     else:
+#         return d
+
+class RayLauncher:
     def __init__(self, strategy: "Strategy") -> None:
         """Initializes RayLauncher."""
         self._strategy = strategy
@@ -54,6 +65,7 @@ def launch(self,
 
         This function is run on the driver process.
         """
+        print("RayLauncher.launch", function)
         self.setup_workers()
         ray_output = self.run_function_on_workers(
             function, *args, trainer=trainer, **kwargs)
@@ -231,9 +243,9 @@ def run_function_on_workers(self,
         """
         # put the model as the ray object
         # and remove the model temporarily from the args
-        model = trainer.model
+        model = trainer.strategy.model
         model_ref = ray.put(model)
-        trainer.model = None
+        trainer.strategy.model = None
         new_args = tuple([None] + list(args[1:]))
 
         # train the model and get the result to rank 0 node
@@ -244,7 +256,7 @@ def run_function_on_workers(self,
             for i, w in enumerate(self._workers)
         ]
 
-        trainer.model = model
+        trainer.strategy.model = model
 
         results = process_results(self._futures, self.tune_queue)
         return results[0]
@@ -284,7 +296,7 @@ def _wrapping_function(
         # by calling `function.__self__` so that we can restore
         # all the side effects happened to `function.__self__`
         trainer = function.__self__
-        trainer.model = model_ref
+        trainer.strategy.model = model_ref
         args = tuple([model_ref] + list(args[1:]))
 
         trainer._data_connector.prepare_data()
@@ -293,6 +305,7 @@ def _wrapping_function(
             init_session(rank=global_rank, queue=tune_queue)
 
         self._strategy._worker_setup(process_idx=global_rank)
+        trainer.strategy.set_remote(True)
         trainer.strategy.root_device = self._strategy.root_device
         trainer.strategy.global_rank = self._strategy.global_rank
         trainer.strategy.local_rank = self._strategy.local_rank
@@ -327,7 +340,8 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",
 
         # Move state_dict to cpu before converting it to model state stream
         if trainer.strategy.local_rank == 0:
-            state_dict = move_data_to_device(state_dict, "cpu")
+            #state_dict = move_data_to_device(state_dict, "cpu")
+            state_dict = {k: v.cpu() for k, v in state_dict.items()}
 
         # PyTorch Lightning saves the model weights in a temp file and
         # loads it back on the driver.
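The launcher keeps its existing pattern of shipping the model through the Ray object store with `ray.put` while temporarily detaching it from the trainer; the diff only changes the attribute it detaches (`trainer.strategy.model` instead of `trainer.model`). A standalone sketch of that object-store pattern, with illustrative names not taken from the diff:

```python
# Standalone sketch of the ray.put pattern used by the launcher above.
import ray
import torch

ray.init(ignore_reinit_error=True)

model = torch.nn.Linear(4, 2)
model_ref = ray.put(model)  # one serialized copy in the object store


@ray.remote
def count_params(model):
    # Ray resolves the ObjectRef before the call, so `model` is the module itself.
    return sum(p.numel() for p in model.parameters())


# Every worker task reads the same stored copy instead of re-pickling the model.
futures = [count_params.remote(model_ref) for _ in range(2)]
print(ray.get(futures))  # [10, 10]
```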
ray_lightning/launchers/utils.py: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 from typing import Any, Optional, NamedTuple, Dict, List, Callable
-from pytorch_lightning.utilities.types import _PATH
+# from pytorch_lightning.utilities.types import _PATH
 from pytorch_lightning.trainer.states import TrainerState
 
 from contextlib import closing
@@ -45,7 +45,7 @@ def get_node_ip(self):
         return ray.util.get_node_ip_address()
 
     def get_node_and_gpu_ids(self):
-        return ray.get_runtime_context().node_id.hex(), ray.get_gpu_ids()
+        return ray.get_runtime_context().get_node_id(), ray.get_gpu_ids()
 
     def execute(self, fn: Callable, *args, **kwargs):
         """Execute the provided function and return the result."""
@@ -61,8 +61,8 @@ class _RayOutput(NamedTuple):
     - `callback_results`: callback result
     - `logged_metrics`: logged metrics
     """
-    best_model_path: Optional[_PATH]
-    weights_path: Optional[_PATH]
+    best_model_path: Optional[Any]
+    weights_path: Optional[Any]
     trainer_state: TrainerState
     trainer_results: Any
     callback_metrics: Dict[str, Any]
ray_lightning/ray_ddp.py: 14 additions & 7 deletions
@@ -4,12 +4,13 @@
 
 import torch
 
-from pytorch_lightning.strategies import DDPSpawnStrategy
+from pytorch_lightning.strategies import DDPStrategy
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
 import ray
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
-from pytorch_lightning.utilities.seed import reset_seed, log
+#from pytorch_lightning.utilities.seed import reset_seed, log
+from lightning_fabric.utilities.seed import reset_seed, log
 from ray.util import PublicAPI
 
 from ray_lightning.launchers import RayLauncher
@@ -20,7 +21,7 @@
 
 
 @PublicAPI(stability="beta")
-class RayStrategy(DDPSpawnStrategy):
+class RayStrategy(DDPStrategy):
     """Pytorch Lightning strategy for DDP training on a Ray cluster.
 
     This strategy is used to manage distributed training using DDP and
@@ -109,6 +110,7 @@ def __init__(self,
         self._is_remote = False
         self._device = None
 
+        ddp_kwargs["start_method"] = "spawn"
         super().__init__(
             accelerator="_gpu" if use_gpu else "cpu",
             parallel_devices=[],
@@ -155,8 +157,12 @@ def set_world_ranks(self, process_idx: int = 0):
         # False, then do a no-op).
         if self._is_remote:
             self._global_rank = process_idx
-            self._local_rank, self._node_rank = self.global_to_local[
-                self.global_rank]
+            self._local_rank, self._node_rank = self.global_to_local[self.global_rank]
+
+    def setup_environment(self) -> None:
+        assert self.accelerator is not None
+        self.accelerator.setup_device(self.root_device)
+        # return super(DDPStrategy).setup_environment()
 
     def _worker_setup(self, process_idx: int):
         """Setup the workers and pytorch DDP connections.
@@ -182,7 +188,7 @@ def _worker_setup(self, process_idx: int):
 
         global_rank = self.global_rank
         world_size = self.world_size
-        torch_distributed_backend = self.torch_distributed_backend
+        torch_distributed_backend = self.process_group_backend
 
         # Taken from pytorch_lightning.utilities.distributed
         if torch.distributed.is_available(
@@ -329,5 +335,6 @@ def teardown(self) -> None:
         This function is overriding ddp_spawn_strategy's method.
         It is run on the driver processes.
         """
-        self.accelerator = None
         super().teardown()
+        if not self._is_remote:
+            self.accelerator = None
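On each worker, `_worker_setup` initializes the torch.distributed process group, with the backend now read from `process_group_backend` rather than the removed `torch_distributed_backend` property. A minimal standalone sketch of that initialization (illustrative single-process values; the real launcher derives the address, rank, and world size from its Ray workers):

```python
# Minimal sketch of worker-side process-group setup, with illustrative values.
import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

if dist.is_available() and not dist.is_initialized():
    dist.init_process_group(
        backend="gloo",  # "nccl" would be the usual choice when use_gpu=True
        rank=0,
        world_size=1,
    )
```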
ray_lightning/util.py: 4 additions & 4 deletions
@@ -2,15 +2,15 @@
 from typing import Callable
 
 import torch
-from pytorch_lightning.accelerators import GPUAccelerator
+from pytorch_lightning.accelerators import CUDAAccelerator
 from pytorch_lightning import Trainer
 from pytorch_lightning.strategies import Strategy
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
 
 import ray
 
 
-class DelayedGPUAccelerator(GPUAccelerator):
+class DelayedGPUAccelerator(CUDAAccelerator):
     """Same as GPUAccelerator, but doesn't do any CUDA setup.
 
     This allows the driver script to be launched from CPU-only machines (
@@ -21,7 +21,7 @@ def setup_environment(self) -> None:
         # Don't do any CUDA setup.
         # Directly call the setup_environment method of the superclass of
         # GPUAccelerator.
-        super(GPUAccelerator, self).setup_environment()
+        super(CUDAAccelerator, self).setup_environment()
 
     def setup(
         self,
@@ -30,7 +30,7 @@ def setup(
         # Don't do any CUDA setup.
         # Directly call the setup_environment method of the superclass of
         # GPUAccelerator.
-        return super(GPUAccelerator, self).setup(trainer)
+        return super(CUDAAccelerator, self).setup(trainer)
 
     def on_train_start(self) -> None:
         if "cuda" not in str(self.root_device):
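The `super(CUDAAccelerator, self)` calls above start the method lookup one level further along the MRO, so the CUDA-specific setup in `CUDAAccelerator` never runs on a possibly CPU-only driver. A self-contained illustration of that pattern with simplified stand-in classes (not the real Lightning classes):

```python
# Simplified stand-ins showing how super(CUDAAccelerator, self) skips one MRO level.
class Accelerator:
    def setup_environment(self) -> None:
        print("generic environment setup")


class CUDAAccelerator(Accelerator):
    def setup_environment(self) -> None:
        print("CUDA device setup")  # the step to avoid on a CPU-only driver
        super().setup_environment()


class DelayedGPUAccelerator(CUDAAccelerator):
    def setup_environment(self) -> None:
        # Begin the lookup after CUDAAccelerator in the MRO, so only
        # Accelerator.setup_environment runs here.
        super(CUDAAccelerator, self).setup_environment()


DelayedGPUAccelerator().setup_environment()  # prints "generic environment setup" only
```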
setup.py: 2 additions & 2 deletions
@@ -3,10 +3,10 @@
 setup(
     name="ray_lightning",
     packages=find_packages(where=".", include="ray_lightning*"),
-    version="0.3.0",
+    version="0.4.0",
     author="Ray Team",
     description="Ray distributed strategies for Pytorch Lightning.",
     long_description="Custom Pytorch Lightning distributed strategies "
     "built on top of distributed computing framework Ray.",
     url="https://github.com/ray-project/ray_lightning_accelerators",
-    install_requires=["pytorch-lightning==1.6.*", "ray"])
+    install_requires=["pytorch-lightning>=2", "ray"])