1+ import io
12from typing import Callable , Dict , List , Union , Any
23
34import os
45from collections import defaultdict
56
6- import ray
77import torch
8+
9+ from pytorch_lightning .accelerators import CPUAccelerator
810from pytorch_lightning .plugins import DDPSpawnPlugin
911from pytorch_lightning import _logger as log , LightningModule
1012from pytorch_lightning .utilities import rank_zero_only
13+
14+ import ray
1115from ray .util .sgd .utils import find_free_port
16+ from ray .util .queue import Queue
1217
1318from ray_lightning .session import init_session
14- from ray_lightning .util import process_results , Queue
19+ from ray_lightning .util import process_results
1520from ray_lightning .tune import TUNE_INSTALLED , is_session_enabled
1621from ray_lightning .ray_environment import RayEnvironment
1722
@@ -161,6 +166,15 @@ def _setup_env_vars(self):
161166 values = [os .getenv (k ) for k in keys ]
162167 ray .get ([w .set_env_vars .remote (keys , values ) for w in self .workers ])
163168
169+ def _load_state_stream (self , state_stream ):
170+ _buffer = io .BytesIO (state_stream )
171+ to_gpu = self .use_gpu and torch .cuda .is_available ()
172+ state_dict = torch .load (
173+ _buffer ,
174+ map_location = ("cpu" if not to_gpu
175+ else lambda storage , loc : storage .cuda ()))
176+ return state_dict
177+
164178 def execution_loop (self , trainer , tune_enabled : bool = True ):
165179 """Main execution loop for training, testing, & prediction.
166180
@@ -194,7 +208,8 @@ def execution_loop(self, trainer, tune_enabled: bool = True):
194208
195209 results = process_results (futures , queue )
196210 # Get the results, checkpoint path, and model weights from worker 0.
197- results , best_path , state_dict = results [0 ]
211+ results , best_path , state_stream = results [0 ]
212+ state_dict = self ._load_state_stream (state_stream )
198213 # Set the state for PTL using the output from remote training.
199214 self ._results = results
200215 self ._model = model
@@ -209,6 +224,24 @@ def execution_loop(self, trainer, tune_enabled: bool = True):
209224
210225 return results
211226
def setup_environment(self) -> None:
    """Swap the trainer's accelerator for a delayed-GPU one when needed.

    PTL resolves the accelerator on the driver process, so a CPU-only
    head node (or a Ray Client driver) gets a ``CPUAccelerator`` even
    when the remote workers will train on GPUs. In that case we replace
    it with ``DelayedGPUAccelerator``, which defers CUDA setup until it
    runs on the GPU workers.
    """
    current_accelerator = self.lightning_module.trainer.accelerator
    if self.use_gpu and isinstance(current_accelerator, CPUAccelerator):
        # Imported lazily — only needed on the CPU-head-with-GPU-workers path.
        from weakref import proxy
        from ray_lightning.util import DelayedGPUAccelerator
        # Reuse the existing precision plugin; this plugin itself acts as
        # the training-type plugin of the new accelerator.
        precision_plugin = current_accelerator.precision_plugin
        new_accelerator = DelayedGPUAccelerator(
            precision_plugin=precision_plugin, training_type_plugin=self)
        # NOTE(review): these assignments poke private attributes of the
        # trainer's accelerator_connector; weakref proxies presumably avoid
        # reference cycles between connector and accelerator — confirm
        # against the targeted pytorch_lightning version's internals.
        self.lightning_module.trainer.accelerator_connector \
            ._training_type_plugin = \
            proxy(new_accelerator.training_type_plugin)
        self.lightning_module.trainer.accelerator_connector \
            ._precision_plugin = proxy(new_accelerator.precision_plugin)
        self.lightning_module.trainer.accelerator_connector.accelerator \
            = new_accelerator
244+
212245 def start_training (self , trainer ):
213246 results = self .execution_loop (trainer , tune_enabled = True )
214247 # reset optimizers, since main process is never used for training and
@@ -268,7 +301,7 @@ def execute_remote(self,
268301 mp_queue = None )
269302 # Only need results from worker 0.
270303 if self .global_rank == 0 :
271- return self .results , self .best_model_path , self .model_state_dict
304+ return self .results , self .best_model_path , self .model_state_stream
272305 else :
273306 return None
274307
@@ -307,12 +340,18 @@ def root_device(self):
307340 else :
308341 return torch .device ("cpu" )
309342
343+ def _to_state_stream (self , model_state_dict ):
344+ _buffer = io .BytesIO ()
345+ torch .save (model_state_dict , _buffer )
346+ return _buffer .getvalue ()
347+
310348 def transfer_distrib_spawn_state_on_fit_end (self , results ):
311349 """Sets the training output as attributes so it can be retrieved."""
312350 if self .global_rank == 0 :
313351 # Save training results as attributes.
314352 self ._results = results
315- self .model_state_dict = self .lightning_module .state_dict ()
353+ self .model_state_stream = \
354+ self ._to_state_stream (self .lightning_module .state_dict ())
316355 best_model_path = None
317356 if self .lightning_module .trainer .checkpoint_callback is not None :
318357 best_model_path = \
0 commit comments