Skip to content

Commit 45b5d6b

Browse files
[RLlib] Fix checkpointable (#60440)
## Description Checkpointable methods `restore_from_path` and `save_to_path` should make use of PyArrow filesystem (especially for the cloud storages use cases) to make sure that RLlib components are correctly saved or restored. --------- Signed-off-by: Kamil Kaczmarek <kamil@anyscale.com>
1 parent 88615ac commit 45b5d6b

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

rllib/utils/checkpoints.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,14 @@ def _get_ip(_=None):
240240
if _state_provided:
241241
comp_state_ref = ray.put(state.pop(comp_name))
242242

243+
# If worker_addr == self_addr, save directly to the path
244+
# provided by the user, make sure to use filesystem.
243245
if worker_ip_addr == self_ip_addr:
244246
comp.foreach_actor(
245-
lambda w, _path=comp_path, _state=comp_state_ref, _use_msgpack=use_msgpack: ( # noqa
247+
lambda w, _path=comp_path, _filesystem=filesystem, _state=comp_state_ref, _use_msgpack=use_msgpack: ( # noqa
246248
w.save_to_path(
247-
_path,
249+
path=_path,
250+
filesystem=_filesystem,
248251
state=(
249252
ray.get(_state)
250253
if _state is not None
@@ -255,6 +258,7 @@ def _get_ip(_=None):
255258
),
256259
remote_actor_ids=[actor_to_use],
257260
)
261+
# Transfer state files from the worker node to the head node
258262
else:
259263
# Save the checkpoint to the temporary directory on the worker.
260264
def _save(w, _state=comp_state_ref, _use_msgpack=use_msgpack):
@@ -263,7 +267,7 @@ def _save(w, _state=comp_state_ref, _use_msgpack=use_msgpack):
263267
# Create a temporary directory on the worker.
264268
tmpdir = tempfile.mkdtemp()
265269
w.save_to_path(
266-
tmpdir,
270+
path=tmpdir,
267271
state=(
268272
ray.get(_state) if _state is not None else w.get_state()
269273
),
@@ -304,7 +308,7 @@ def _rmdir(_, _dir=worker_temp_dir):
304308
# have to call its own `get_state()` anymore, but uses what's provided
305309
# here.
306310
comp.save_to_path(
307-
comp_path,
311+
path=comp_path,
308312
filesystem=filesystem,
309313
state=comp_state,
310314
use_msgpack=use_msgpack,
@@ -387,7 +391,10 @@ def restore_from_path(
387391
# Restore components of `self` that themselves are `Checkpointable`.
388392
orig_comp_names = {c[0] for c in self.get_checkpointable_components()}
389393
self._restore_all_subcomponents_from_path(
390-
path, filesystem, component=component, **kwargs
394+
path=path,
395+
filesystem=filesystem,
396+
component=component,
397+
**kwargs,
391398
)
392399

393400
# Restore the "base" state (not individual subcomponents).
@@ -410,7 +417,10 @@ def restore_from_path(
410417
diff_comp_names = new_comp_names - orig_comp_names
411418
if diff_comp_names:
412419
self._restore_all_subcomponents_from_path(
413-
path, filesystem, only_comp_names=diff_comp_names, **kwargs
420+
path=path,
421+
filesystem=filesystem,
422+
only_comp_names=diff_comp_names,
423+
**kwargs,
414424
)
415425

416426
@classmethod
@@ -668,7 +678,12 @@ def _restore(
668678
# directly from the path otherwise sync the checkpoint from the head
669679
# to the worker and load it from there.
670680
if worker_node_ip == _head_ip:
671-
w.restore_from_path(_path, component=_comp_arg, **_kwargs)
681+
w.restore_from_path(
682+
path=_path,
683+
filesystem=filesystem,
684+
component=_comp_arg,
685+
**_kwargs,
686+
)
672687
else:
673688
with tempfile.TemporaryDirectory() as temp_dir:
674689
sync_dir_between_nodes(

0 commit comments

Comments
 (0)