This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 1059558

MarkusSpanring and mspanring authored
move state dict to cpu before converting to state stream (#208)
Co-authored-by: mspanring <[email protected]>
1 parent d0be22e commit 1059558

File tree

1 file changed: +7 -4 lines changed


ray_lightning/launchers/ray_launcher.py

Lines changed: 7 additions & 4 deletions
@@ -301,10 +301,9 @@ def _wrapping_function(
         results = function(*args, **kwargs)
 
         if trainer is not None:
-            results = self._collect_rank_zero_results(trainer, results)
-
-            if trainer.strategy.local_rank == 0:
-                return move_data_to_device(results, "cpu")
+            return self._collect_rank_zero_results(trainer, results)
+        else:
+            return None
 
         trainer._teardown()
         trainer._call_teardown_hook()
@@ -326,6 +325,10 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",
         if self._strategy.global_rank != 0:
             return None
 
+        # Move state_dict to cpu before converting it to model state stream
+        if trainer.strategy.local_rank == 0:
+            state_dict = move_data_to_device(state_dict, "cpu")
+
         # PyTorch Lightning saves the model weights in a temp file and
         # loads it back on the driver.
         # This won't work in a multi-node setup though, so we return the
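For context, the pattern this commit applies looks roughly like the sketch below: on the rank-0 worker, the tensors in the state dict are moved to CPU with PyTorch Lightning's move_data_to_device before the dict is serialized into a byte stream that can be returned to the driver. The _state_dict_to_stream and collect_weights helpers here are hypothetical stand-ins used only for illustration, not ray_lightning's actual API.

import io

import torch
from pytorch_lightning.utilities import move_data_to_device


def _state_dict_to_stream(state_dict: dict) -> bytes:
    # Hypothetical stand-in for ray_lightning's model-state-stream helper:
    # serialize the state dict into bytes so a remote Ray worker can return
    # it to the driver without relying on a shared filesystem.
    buffer = io.BytesIO()
    torch.save(state_dict, buffer)
    return buffer.getvalue()


def collect_weights(model: torch.nn.Module, local_rank: int) -> bytes:
    state_dict = model.state_dict()

    # Mirror the commit's fix: move tensors to CPU on the rank-0 worker
    # before converting the state dict to a stream.
    if local_rank == 0:
        state_dict = move_data_to_device(state_dict, "cpu")

    return _state_dict_to_stream(state_dict)

A typical motivation for moving weights to CPU before serializing is that restoring CUDA tensors on a driver without a matching GPU setup can otherwise fail or hold GPU memory unnecessarily.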

0 commit comments