
Commit d5f325a

Fix hanging trainer.test() (#142)
Closes #132.

Fixes an issue where trainer.test() hangs when the test DataLoader uses multiple workers. The issue is a bit strange, as I was only able to reproduce the hang with that exact setup. It does not occur with trainer.train() and multiple workers in the training DataLoader, and it also does not occur when num_workers for the test DataLoader is set to 0.

I'm not exactly sure what's going on, but testing actually finishes and the program hangs at torch.distributed.destroy_process_group(). It may be related to pytorch/pytorch#75097. The difference between trainer.train() and trainer.test() is that the former wraps the model in DDP while the latter doesn't (but still creates the process group).

In any case, the shutdown_remote cleanup code is not actually necessary: the CUDA cache cleanup is already called by the parent DDPSpawnPlugin on each worker, and torch.distributed.destroy_process_group() is not a public API (and is not called by PyTorch Lightning either).

The test added in this PR hangs prior to the changes to ray_ddp.py and passes after them.
1 parent 3c3e9d4 commit d5f325a
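For context, a minimal reproduction sketch of the hanging setup. It mirrors the test added in this PR; the import paths, the explicit ray.init(), and the use of a temporary directory in place of pytest's tmpdir fixture are assumptions, not part of the actual test code.

# Minimal reproduction sketch (mirrors the test added below); import paths
# and the standalone setup are assumptions based on the repo's test utilities.
import tempfile

import ray
from ray_lightning import RayPlugin
from ray_lightning.tests.utils import BoringModel, get_trainer

ray.init(num_cpus=2)

# BoringModel's test_dataloader uses num_workers=1 (see the utils.py change
# below), which is the combination that made trainer.test() hang.
model = BoringModel()
plugin = RayPlugin(num_workers=1, use_gpu=False)
trainer = get_trainer(
    tempfile.mkdtemp(), limit_train_batches=20, max_epochs=1, plugins=[plugin])

# Before this change: testing finishes, then the program hangs in the remote
# torch.distributed.destroy_process_group() call. After: returns normally.
trainer.test(model)

ray.shutdown()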

File tree

3 files changed: +11 -7 lines changed


ray_lightning/ray_ddp.py

Lines changed: 0 additions & 6 deletions

@@ -380,12 +380,6 @@ def post_dispatch(self, trainer: "pl.Trainer"):
                 .best_model_path = best_path
         # DDPSpawnPlugin.__recover_child_process_weights_end

-        def shutdown_remote():
-            torch.distributed.destroy_process_group()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-        ray.get([w.execute.remote(shutdown_remote) for w in self.workers])
         for w in self.workers:
             ray.kill(w, no_restart=True)
             del w

ray_lightning/tests/test_ddp.py

Lines changed: 9 additions & 0 deletions

@@ -227,6 +227,15 @@ def test_train_client(tmpdir, start_ray_client_server_2_cpus, num_workers):
     train_test(trainer, model)


+def test_test_with_dataloader_workers(tmpdir, ray_start_2_cpus, seed):
+    """Tests trainer.test with >0 workers for data loading."""
+    model = BoringModel()
+    plugin = RayPlugin(num_workers=1, use_gpu=False)
+    trainer = get_trainer(
+        tmpdir, limit_train_batches=20, max_epochs=1, plugins=[plugin])
+    trainer.test(model)
+
+
 @pytest.mark.parametrize("num_workers", [1, 2])
 def test_load(tmpdir, ray_start_2_cpus, num_workers):
     """Tests if model checkpoint can be loaded."""

ray_lightning/tests/utils.py

Lines changed: 2 additions & 1 deletion

@@ -86,7 +86,8 @@ def val_dataloader(self):
         return torch.utils.data.DataLoader(RandomDataset(32, 64))

     def test_dataloader(self):
-        return torch.utils.data.DataLoader(RandomDataset(32, 64))
+        return torch.utils.data.DataLoader(
+            RandomDataset(32, 64), num_workers=1)

     def on_save_checkpoint(self, checkpoint):
         checkpoint["val_epoch"] = self.val_epoch
