
Commit 5a28cd2

pass in ddp_kwargs (#28)
1 parent c13bbb1

2 files changed: 22 additions & 3 deletions

ray_lightning/ray_ddp.py

Lines changed: 7 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List
+from typing import Callable, Dict, List, Union, Any
 
 import os
 from collections import defaultdict
@@ -65,6 +65,8 @@ class RayPlugin(DDPSpawnPlugin):
             Trainer to a value > 0.
         init_hook (Callable): A function to run on each worker
             upon instantiation.
+        **ddp_kwargs: Additional arguments to pass into
+            ``DistributedDataParallel`` initialization
 
     Example:
 
@@ -88,10 +90,12 @@ def __init__(self,
                  num_workers: int = 1,
                  num_cpus_per_worker: int = 1,
                  use_gpu: bool = False,
-                 init_hook: Callable = None):
+                 init_hook: Callable = None,
+                 **ddp_kwargs: Union[Any, Dict[str, Any]]):
         if not ray.is_initialized():
             ray.init()
-        super().__init__(sync_batchnorm=None, parallel_devices=[])
+        super().__init__(
+            sync_batchnorm=None, parallel_devices=[], **ddp_kwargs)
         self.nickname = "ddp_ray"
         self.num_workers = num_workers
         self.num_cpus_per_worker = num_cpus_per_worker
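
With this change, any keyword argument that RayPlugin does not consume itself is collected into **ddp_kwargs and handed to the parent DDPSpawnPlugin, which forwards it to DistributedDataParallel. A minimal usage sketch, assuming RayPlugin is exported from the ray_lightning package top level and a standard PyTorch Lightning setup:

import pytorch_lightning as pl
from ray_lightning import RayPlugin

# Extra keywords (here, find_unused_parameters) are not named parameters
# of RayPlugin; they land in **ddp_kwargs and are forwarded to torch's
# DistributedDataParallel wrapper.
plugin = RayPlugin(
    num_workers=2,
    num_cpus_per_worker=1,
    use_gpu=False,
    find_unused_parameters=False)

trainer = pl.Trainer(max_epochs=1, plugins=[plugin])
# trainer.fit(model)  # model: any LightningModule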

ray_lightning/tests/test_ddp.py

Lines changed: 15 additions & 0 deletions
@@ -130,3 +130,18 @@ def test_early_stop(tmpdir, ray_start_2_cpus):
     trained_model = BoringModel.load_from_checkpoint(
         trainer.checkpoint_callback.best_model_path)
     assert trained_model.val_epoch == 2, trained_model.val_epoch
+
+
+def test_unused_parameters(tmpdir, ray_start_2_cpus):
+    """Tests if find_unused_parameters is properly passed to model."""
+    model = BoringModel()
+    plugin = RayPlugin(
+        num_workers=2, use_gpu=False, find_unused_parameters=False)
+
+    class UnusedParameterCallback(Callback):
+        def on_train_start(self, trainer, pl_module):
+            assert trainer.model.find_unused_parameters is False
+
+    trainer = get_trainer(
+        tmpdir, plugins=[plugin], callbacks=[UnusedParameterCallback()])
+    trainer.fit(model)
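
The test works because the parent plugin stashes the extra keywords and unpacks them when it wraps the model for distributed training. A simplified, illustrative sketch of that upstream mechanism (loosely modeled on pytorch_lightning's DDPSpawnPlugin; exact names and signatures vary across versions):

from torch.nn.parallel import DistributedDataParallel

class DDPSpawnPluginSketch:
    """Illustrative stand-in for pytorch_lightning's DDPSpawnPlugin."""

    def __init__(self, sync_batchnorm=None, parallel_devices=None, **kwargs):
        # Unrecognized keywords (e.g. find_unused_parameters) are stashed...
        self._ddp_kwargs = kwargs

    def configure_ddp(self, model, device_ids):
        # ...and unpacked into the DistributedDataParallel constructor,
        # which is why the test can read find_unused_parameters off
        # trainer.model once training starts.
        return DistributedDataParallel(
            model, device_ids=device_ids, **self._ddp_kwargs)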
