Skip to content

Commit ae4e1dd

Browse files
authored
[tune] [rllib] Allow checkpointing to object store instead of local disk (#1212)
* wip * use normal pickle * fix checkpoint test * comment * Comment * fix test * fix lint * fix py 3.5 * Update agent.py * fix lint
1 parent d986294 commit ae4e1dd

File tree

3 files changed

+100
-12
lines changed

3 files changed

+100
-12
lines changed

python/ray/rllib/agent.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66

77
import logging
88
import numpy as np
9+
import io
910
import os
11+
import gzip
1012
import pickle
13+
import shutil
1114
import tempfile
1215
import time
1316
import uuid
@@ -147,6 +150,35 @@ def save(self):
147150
open(checkpoint_path + ".rllib_metadata", "wb"))
148151
return checkpoint_path
149152

153+
def save_to_object(self):
154+
"""Saves the current model state to a Python object. It also
155+
saves to disk but does not return the checkpoint path.
156+
157+
Returns:
158+
Object holding checkpoint data.
159+
"""
160+
161+
checkpoint_prefix = self.save()
162+
163+
data = {}
164+
base_dir = os.path.dirname(checkpoint_prefix)
165+
for path in os.listdir(base_dir):
166+
path = os.path.join(base_dir, path)
167+
if path.startswith(checkpoint_prefix):
168+
data[os.path.basename(path)] = open(path, "rb").read()
169+
170+
out = io.BytesIO()
171+
with gzip.GzipFile(fileobj=out, mode="wb") as f:
172+
compressed = pickle.dumps({
173+
"checkpoint_name": os.path.basename(checkpoint_prefix),
174+
"data": data,
175+
})
176+
print("Saving checkpoint to object store, {} bytes".format(
177+
len(compressed)))
178+
f.write(compressed)
179+
180+
return out.getvalue()
181+
150182
def restore(self, checkpoint_path):
151183
"""Restores training state from a given model checkpoint.
152184
@@ -160,6 +192,25 @@ def restore(self, checkpoint_path):
160192
self._timesteps_total = metadata[2]
161193
self._time_total = metadata[3]
162194

195+
def restore_from_object(self, obj):
    """Restores training state from a checkpoint object.

    These checkpoints are returned from calls to save_to_object().
    """

    out = io.BytesIO(obj)
    # NOTE: pickle is only safe because checkpoint objects are produced
    # by save_to_object() within the same trusted cluster; never feed
    # untrusted data here.
    with gzip.GzipFile(fileobj=out, mode="rb") as f:
        info = pickle.loads(f.read())
    data = info["data"]
    tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
    checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])

    try:
        # Materialize the checkpoint files on disk so the normal
        # path-based restore() can consume them.
        for file_name, file_contents in data.items():
            with open(os.path.join(tmpdir, file_name), "wb") as f:
                f.write(file_contents)

        self.restore(checkpoint_path)
    finally:
        # Clean up the scratch dir even if restore() raises (the
        # original leaked it on error).
        shutil.rmtree(tmpdir)
213+
163214
def stop(self):
164215
"""Releases all resources used by this agent."""
165216

python/ray/rllib/test/test_checkpoint_restore.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ def get_mean_action(alg, obs):
2626
"A3C": {"use_lstm": False},
2727
}
2828

29-
for name in ["ES", "DQN", "PPO", "A3C"]:
30-
cls = get_agent_class(name)
29+
30+
def test(use_object_store, alg_name):
31+
cls = get_agent_class(alg_name)
3132
alg1 = cls("CartPole-v0", CONFIGS[name])
3233
alg2 = cls("CartPole-v0", CONFIGS[name])
3334

@@ -36,11 +37,23 @@ def get_mean_action(alg, obs):
3637
print("current status: " + str(res))
3738

3839
# Sync the models
39-
alg2.restore(alg1.save())
40+
if use_object_store:
41+
alg2.restore_from_object(alg1.save_to_object())
42+
else:
43+
alg2.restore(alg1.save())
4044

4145
for _ in range(10):
4246
obs = np.random.uniform(size=4)
4347
a1 = get_mean_action(alg1, obs)
4448
a2 = get_mean_action(alg2, obs)
4549
print("Checking computed actions", alg1, obs, a1, a2)
46-
assert abs(a1-a2) < .1, (a1, a2)
50+
assert abs(a1 - a2) < .1, (a1, a2)
51+
52+
53+
if __name__ == "__main__":
    # Exercise both restore paths (local disk and object store) for
    # every supported algorithm.
    # https://github.com/ray-project/ray/issues/1062 for enabling ES test too
    for use_object_store in (False, True):
        # NOTE: the loop variable must stay named `name` — test() reads
        # CONFIGS[name] via this global.
        for name in ("ES", "DQN", "PPO", "A3C"):
            test(use_object_store, name)

    print("All checkpoint restore tests passed!")

python/ray/tune/trial.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def __init__(
9191
# Local trial state that is updated during the run
9292
self.last_result = None
9393
self._checkpoint_path = restore_path
94+
self._checkpoint_obj = None
9495
self.agent = None
9596
self.status = Trial.PENDING
9697
self.location = None
@@ -106,7 +107,9 @@ def start(self):
106107

107108
self._setup_agent()
108109
if self._checkpoint_path:
109-
self.restore_from_path(path=self._checkpoint_path)
110+
self.restore_from_path(self._checkpoint_path)
111+
elif self._checkpoint_obj:
112+
self.restore_from_obj(self._checkpoint_obj)
110113

111114
def stop(self, error=False, stop_logger=True):
112115
"""Stops this trial.
@@ -152,7 +155,7 @@ def pause(self):
152155

153156
assert self.status == Trial.RUNNING, self.status
154157
try:
155-
self.checkpoint()
158+
self.checkpoint(to_object_store=True)
156159
self.stop(stop_logger=False)
157160
self.status = Trial.PAUSED
158161
except Exception:
@@ -226,16 +229,25 @@ def location_string(hostname, pid):
226229

227230
return ', '.join(pieces)
228231

229-
def checkpoint(self, to_object_store=False):
    """Checkpoints the state of this trial.

    Args:
        to_object_store (bool): Whether to save to the Ray object store
            (async) vs a path on local disk (sync).
    """

    checkpoint_path = None
    checkpoint_obj = None
    if to_object_store:
        # Async save: keep the remote object ref without blocking on
        # ray.get().
        checkpoint_obj = self.agent.save_to_object.remote()
    else:
        # Sync save: block until the agent reports the on-disk path.
        checkpoint_path = ray.get(self.agent.save.remote())

    # Record whichever form was produced; the other slot is cleared so a
    # later start() restores from the most recent checkpoint.
    self._checkpoint_path = checkpoint_path
    self._checkpoint_obj = checkpoint_obj

    print("Saved checkpoint to:", checkpoint_path or checkpoint_obj)
    return checkpoint_path or checkpoint_obj
239251

240252
def restore_from_path(self, path):
241253
"""Restores agent state from specified path.
@@ -253,6 +265,18 @@ def restore_from_path(self, path):
253265
print("Error restoring agent:", traceback.format_exc())
254266
self.status = Trial.ERROR
255267

268+
def restore_from_obj(self, obj):
    """Restores agent state from the specified object."""

    # Guard clause: nothing to restore into without a live agent.
    if self.agent is None:
        print("Unable to restore - no agent")
        return

    try:
        ray.get(self.agent.restore_from_object.remote(obj))
    except Exception:
        print("Error restoring agent:", traceback.format_exc())
        self.status = Trial.ERROR
256280
def _setup_agent(self):
257281
self.status = Trial.RUNNING
258282
agent_cls = get_agent_class(self.alg)

0 commit comments

Comments
 (0)