@@ -240,11 +240,14 @@ def _get_ip(_=None):
240240 if _state_provided :
241241 comp_state_ref = ray .put (state .pop (comp_name ))
242242
243+ # If worker_ip_addr == self_ip_addr, save directly to the path
244+ # provided by the user, and make sure to pass the filesystem along.
243245 if worker_ip_addr == self_ip_addr :
244246 comp .foreach_actor (
245- lambda w , _path = comp_path , _state = comp_state_ref , _use_msgpack = use_msgpack : ( # noqa
247+ lambda w , _path = comp_path , _filesystem = filesystem , _state = comp_state_ref , _use_msgpack = use_msgpack : ( # noqa
246248 w .save_to_path (
247- _path ,
249+ path = _path ,
250+ filesystem = _filesystem ,
248251 state = (
249252 ray .get (_state )
250253 if _state is not None
@@ -255,6 +258,7 @@ def _get_ip(_=None):
255258 ),
256259 remote_actor_ids = [actor_to_use ],
257260 )
261+ # Transfer state files from the worker node to the head node
258262 else :
259263 # Save the checkpoint to the temporary directory on the worker.
260264 def _save (w , _state = comp_state_ref , _use_msgpack = use_msgpack ):
@@ -263,7 +267,7 @@ def _save(w, _state=comp_state_ref, _use_msgpack=use_msgpack):
263267 # Create a temporary directory on the worker.
264268 tmpdir = tempfile .mkdtemp ()
265269 w .save_to_path (
266- tmpdir ,
270+ path = tmpdir ,
267271 state = (
268272 ray .get (_state ) if _state is not None else w .get_state ()
269273 ),
@@ -304,7 +308,7 @@ def _rmdir(_, _dir=worker_temp_dir):
304308 # have to call its own `get_state()` anymore, but uses what's provided
305309 # here.
306310 comp .save_to_path (
307- comp_path ,
311+ path = comp_path ,
308312 filesystem = filesystem ,
309313 state = comp_state ,
310314 use_msgpack = use_msgpack ,
@@ -387,7 +391,10 @@ def restore_from_path(
387391 # Restore components of `self` that themselves are `Checkpointable`.
388392 orig_comp_names = {c [0 ] for c in self .get_checkpointable_components ()}
389393 self ._restore_all_subcomponents_from_path (
390- path , filesystem , component = component , ** kwargs
394+ path = path ,
395+ filesystem = filesystem ,
396+ component = component ,
397+ ** kwargs ,
391398 )
392399
393400 # Restore the "base" state (not individual subcomponents).
@@ -410,7 +417,10 @@ def restore_from_path(
410417 diff_comp_names = new_comp_names - orig_comp_names
411418 if diff_comp_names :
412419 self ._restore_all_subcomponents_from_path (
413- path , filesystem , only_comp_names = diff_comp_names , ** kwargs
420+ path = path ,
421+ filesystem = filesystem ,
422+ only_comp_names = diff_comp_names ,
423+ ** kwargs ,
414424 )
415425
416426 @classmethod
@@ -668,7 +678,12 @@ def _restore(
668678 # directly from the path otherwise sync the checkpoint from the head
669679 # to the worker and load it from there.
670680 if worker_node_ip == _head_ip :
671- w .restore_from_path (_path , component = _comp_arg , ** _kwargs )
681+ w .restore_from_path (
682+ path = _path ,
683+ filesystem = filesystem ,
684+ component = _comp_arg ,
685+ ** _kwargs ,
686+ )
672687 else :
673688 with tempfile .TemporaryDirectory () as temp_dir :
674689 sync_dir_between_nodes (
0 commit comments