|
1 | | -import io |
2 | 1 | import socket |
3 | 2 | from contextlib import closing |
4 | 3 | from typing import Callable, Dict, List, Union, Any |
|
17 | 16 | from ray.util.queue import Queue |
18 | 17 |
|
19 | 18 | from ray_lightning.session import init_session |
20 | | -from ray_lightning.util import process_results |
| 19 | +from ray_lightning.util import process_results, to_state_stream, \ |
| 20 | + load_state_stream |
21 | 21 | from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled |
22 | 22 | from ray_lightning.ray_environment import RayEnvironment |
23 | 23 |
|
@@ -174,15 +174,6 @@ def _setup_env_vars(self): |
174 | 174 | values = [os.getenv(k) for k in keys] |
175 | 175 | ray.get([w.set_env_vars.remote(keys, values) for w in self.workers]) |
176 | 176 |
|
177 | | - def _load_state_stream(self, state_stream): |
178 | | - _buffer = io.BytesIO(state_stream) |
179 | | - to_gpu = self.use_gpu and torch.cuda.is_available() |
180 | | - state_dict = torch.load( |
181 | | - _buffer, |
182 | | - map_location=("cpu" if not to_gpu |
183 | | - else lambda storage, loc: storage.cuda())) |
184 | | - return state_dict |
185 | | - |
186 | 177 | def execution_loop(self, trainer, tune_enabled: bool = True): |
187 | 178 | """Main execution loop for training, testing, & prediction. |
188 | 179 |
|
@@ -217,7 +208,7 @@ def execution_loop(self, trainer, tune_enabled: bool = True): |
217 | 208 | results = process_results(futures, queue) |
218 | 209 | # Get the results, checkpoint path, and model weights from worker 0. |
219 | 210 | results, best_path, state_stream = results[0] |
220 | | - state_dict = self._load_state_stream(state_stream) |
| 211 | + state_dict = load_state_stream(state_stream, to_gpu=self.use_gpu) |
221 | 212 | # Set the state for PTL using the output from remote training. |
222 | 213 | self._results = results |
223 | 214 | self._model = model |
@@ -348,18 +339,13 @@ def root_device(self): |
348 | 339 | else: |
349 | 340 | return torch.device("cpu") |
350 | 341 |
|
351 | | - def _to_state_stream(self, model_state_dict): |
352 | | - _buffer = io.BytesIO() |
353 | | - torch.save(model_state_dict, _buffer) |
354 | | - return _buffer.getvalue() |
355 | | - |
356 | 342 | def transfer_distrib_spawn_state_on_fit_end(self, results): |
357 | 343 | """Sets the training output as attributes so it can be retrieved.""" |
358 | 344 | if self.global_rank == 0: |
359 | 345 | # Save training results as attributes. |
360 | 346 | self._results = results |
361 | 347 | self.model_state_stream = \ |
362 | | - self._to_state_stream(self.lightning_module.state_dict()) |
| 348 | + to_state_stream(self.lightning_module.state_dict()) |
363 | 349 | best_model_path = None |
364 | 350 | if self.lightning_module.trainer.checkpoint_callback is not None: |
365 | 351 | best_model_path = \ |
|
0 commit comments