Skip to content

Commit eadb998

Browse files
authored
[tune] Make HyperBand Usable (#1215)
1 parent 3a0206a commit eadb998

File tree

3 files changed

+348
-162
lines changed

3 files changed

+348
-162
lines changed

python/ray/rllib/tuned_examples/hyperband-cartpole.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cartpole-ppo:
22
env: CartPole-v0
33
alg: PPO
4-
num_trials: 20
4+
repeat: 3
55
stop:
66
episode_reward_mean: 200
77
time_total_s: 180

python/ray/tune/hyperband.py

Lines changed: 97 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88
from ray.tune.trial import Trial
99

1010

11-
def calculate_bracket_count(max_iter, eta):
12-
return int(np.log(max_iter)/np.log(eta)) + 1
13-
14-
1511
class HyperBandScheduler(FIFOScheduler):
1612
"""Implements HyperBand.
1713
14+
Blog post: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html
15+
1816
This implementation contains 3 logical levels.
1917
Each HyperBand iteration is a "band". There can be multiple
2018
bands running at once, and there can be 1 band that is incomplete.
@@ -30,33 +28,48 @@ class HyperBandScheduler(FIFOScheduler):
3028
3129
Trials added will be inserted into the most recent bracket
3230
and band and will spill over to new brackets/bands accordingly.
33-
"""
3431
35-
def __init__(self, max_iter=200, eta=3):
36-
"""
37-
args:
38-
max_iter (int): maximum iterations per configuration
39-
eta (int): # defines downsampling rate (default=3)
40-
"""
41-
assert max_iter > 0, "Max Iterations not valid!"
42-
assert eta > 1, "Downsampling rate (eta) not valid!"
32+
This maintains the bracket size and max trial count per band
33+
to 5 and 117 respectively, which correspond to that of
34+
`max_attr=81, eta=3` from the blog post. Trials will fill up
35+
from smallest bracket to largest, with largest
36+
having the most rounds of successive halving.
37+
38+
Args:
39+
time_attr (str): The TrainingResult attr to use for comparing time.
40+
Note that you can pass in something non-temporal such as
41+
`training_iteration` as a measure of progress, the only requirement
42+
is that the attribute should increase monotonically.
43+
reward_attr (str): The TrainingResult objective value attribute. As
44+
with `time_attr`, this may refer to any objective value. Stopping
45+
procedures will use this attribute.
46+
max_t (int): max time units per trial. Trials will be stopped after
47+
max_t time units (determined by time_attr) have passed.
48+
The HyperBand scheduler automatically tries to determine a
49+
reasonable number of brackets based on this and eta.
50+
"""
4351

52+
def __init__(
53+
self, time_attr='training_iteration',
54+
reward_attr='episode_reward_mean', max_t=81):
55+
assert max_t > 0, "Max (time_attr) not valid!"
4456
FIFOScheduler.__init__(self)
45-
self._eta = eta
46-
self._s_max_1 = s_max_1 = calculate_bracket_count(max_iter, eta)
47-
# total number of iterations per execution of Succesive Halving (n,r)
48-
B = s_max_1 * max_iter
49-
# bracket trial count total
50-
self._get_n0 = lambda s: int(np.ceil(B/max_iter/(s+1)*eta**s))
57+
self._eta = 3
58+
self._s_max_1 = 5
59+
# bracket max trials
60+
self._get_n0 = lambda s: int(
61+
np.ceil(self._s_max_1/(s+1) * self._eta**s))
5162
# bracket initial iterations
52-
self._get_r0 = lambda s: int(max_iter*eta**(-s))
63+
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
5364
self._hyperbands = [[]] # list of hyperband iterations
5465
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
5566

5667
# Tracks state for new trial add
5768
self._state = {"bracket": None,
5869
"band_idx": 0}
5970
self._num_stopped = 0
71+
self._reward_attr = reward_attr
72+
self._time_attr = time_attr
6073

6174
def on_trial_add(self, trial_runner, trial):
6275
"""On a new trial add, if current bracket is not filled,
@@ -67,22 +80,27 @@ def on_trial_add(self, trial_runner, trial):
6780
cur_bracket = self._state["bracket"]
6881
cur_band = self._hyperbands[self._state["band_idx"]]
6982
if cur_bracket is None or cur_bracket.filled():
70-
71-
# if current iteration is filled, create new iteration
72-
if self._cur_band_filled():
73-
cur_band = []
74-
self._hyperbands.append(cur_band)
75-
self._state["band_idx"] += 1
76-
77-
# cur_band will always be less than s_max_1 or else filled
78-
s = len(cur_band)
79-
assert s < self._s_max_1, "Current band is filled!"
80-
81-
# create new bracket
82-
cur_bracket = Bracket(self._get_n0(s),
83-
self._get_r0(s), self._eta, s)
84-
cur_band.append(cur_bracket)
85-
self._state["bracket"] = cur_bracket
83+
retry = True
84+
while retry:
85+
# if current iteration is filled, create new iteration
86+
if self._cur_band_filled():
87+
cur_band = []
88+
self._hyperbands.append(cur_band)
89+
self._state["band_idx"] += 1
90+
91+
# cur_band will always be less than s_max_1 or else filled
92+
s = len(cur_band)
93+
assert s < self._s_max_1, "Current band is filled!"
94+
if self._get_r0(s) == 0:
95+
print("Bracket too small - Retrying...")
96+
cur_bracket = None
97+
else:
98+
retry = False
99+
cur_bracket = Bracket(
100+
self._time_attr, self._get_n0(s), self._get_r0(s),
101+
self._eta, s)
102+
cur_band.append(cur_bracket)
103+
self._state["bracket"] = cur_bracket
86104

87105
self._state["bracket"].add_trial(trial)
88106
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
@@ -128,9 +146,9 @@ def _process_bracket(self, trial_runner, bracket, trial):
128146
if bracket.cur_iter_done():
129147
if bracket.finished():
130148
self._cleanup_bracket(trial_runner, bracket)
131-
return TrialScheduler.STOP
149+
return TrialScheduler.CONTINUE
132150

133-
good, bad = bracket.successive_halving()
151+
good, bad = bracket.successive_halving(self._reward_attr)
134152
# kill bad trials
135153
for t in bad:
136154
if t.status == Trial.PAUSED:
@@ -141,14 +159,15 @@ def _process_bracket(self, trial_runner, bracket, trial):
141159
else:
142160
raise Exception("Trial with unexpected status encountered")
143161

144-
# ready the good trials
162+
# ready the good trials - if trial is too far ahead, don't continue
145163
for t in good:
146-
if t.status == Trial.PAUSED:
147-
t.unpause()
148-
elif t.status == Trial.RUNNING:
149-
action = TrialScheduler.CONTINUE
150-
else:
164+
if t.status not in [Trial.PAUSED, Trial.RUNNING]:
151165
raise Exception("Trial with unexpected status encountered")
166+
if bracket.continue_trial(t):
167+
if t.status == Trial.PAUSED:
168+
t.unpause()
169+
elif t.status == Trial.RUNNING:
170+
action = TrialScheduler.CONTINUE
152171
return action
153172

154173
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
@@ -162,11 +181,14 @@ def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
162181
bracket.cleanup_trial(t)
163182

164183
def _cleanup_bracket(self, trial_runner, bracket):
165-
"""Cleans up bracket after bracket is completely finished."""
184+
"""Cleans up bracket after bracket is completely finished.
185+
Lets the last trial continue to run until termination condition
186+
kicks in."""
166187
for trial in bracket.current_trials():
167-
self._cleanup_trial(
168-
trial_runner, trial, bracket,
169-
hard=(trial.status == Trial.PAUSED))
188+
if (trial.status == Trial.PAUSED):
189+
self._cleanup_trial(
190+
trial_runner, trial, bracket,
191+
hard=True)
170192

171193
def on_trial_complete(self, trial_runner, trial, result):
172194
"""Cleans up trial info from bracket if trial completed early."""
@@ -219,12 +241,15 @@ class Bracket():
219241
220242
Also keeps track of progress to ensure good scheduling.
221243
"""
222-
def __init__(self, max_trials, init_iters, eta, s):
223-
self._live_trials = {} # stores (result, itrs left before halving)
244+
def __init__(self, time_attr, max_trials, init_t_attr, eta, s):
245+
self._live_trials = {} # maps trial -> current result
224246
self._all_trials = []
247+
self._time_attr = time_attr # attribute to
248+
225249
self._n = self._n0 = max_trials
226-
self._r = self._r0 = init_iters
250+
self._r = self._r0 = init_t_attr
227251
self._cumul_r = self._r0
252+
228253
self._eta = eta
229254
self._halves = s
230255

@@ -237,15 +262,15 @@ def add_trial(self, trial):
237262
At a later iteration, a newly added trial will be given equal
238263
opportunity to catch up."""
239264
assert not self.filled(), "Cannot add trial to filled bracket!"
240-
self._live_trials[trial] = (None, self._cumul_r)
265+
self._live_trials[trial] = None
241266
self._all_trials.append(trial)
242267

243268
def cur_iter_done(self):
244269
"""Checks if all iterations have completed.
245270
246271
TODO(rliaw): also check that `t.iterations == self._r`"""
247-
all_done = all(itr == 0 for _, itr in self._live_trials.values())
248-
return all_done
272+
return all(self._get_result_time(result) >= self._cumul_r
273+
for result in self._live_trials.values())
249274

250275
def finished(self):
251276
return self._halves == 0 and self.cur_iter_done()
@@ -254,8 +279,8 @@ def current_trials(self):
254279
return list(self._live_trials)
255280

256281
def continue_trial(self, trial):
257-
_, itr = self._live_trials[trial]
258-
if itr > 0:
282+
result = self._live_trials[trial]
283+
if self._get_result_time(result) < self._cumul_r:
259284
return True
260285
else:
261286
return False
@@ -265,24 +290,19 @@ def filled(self):
265290
minimizing the need to backtrack and bookkeep previous medians"""
266291
return len(self._live_trials) == self._n
267292

268-
def successive_halving(self):
293+
def successive_halving(self, reward_attr):
269294
assert self._halves > 0
270295
self._halves -= 1
271296
self._n /= self._eta
272297
self._n = int(np.ceil(self._n))
273298
self._r *= self._eta
274-
self._r = int(np.ceil(self._r))
299+
self._r = int((self._r))
275300
self._cumul_r += self._r
276301
sorted_trials = sorted(
277302
self._live_trials,
278-
key=lambda t: self._live_trials[t][0].episode_reward_mean)
303+
key=lambda t: getattr(self._live_trials[t], reward_attr))
279304

280305
good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n]
281-
282-
# reset good trials to track updated iterations
283-
for t in good:
284-
res, old_itr = self._live_trials[t]
285-
self._live_trials[t] = (res, self._r)
286306
return good, bad
287307

288308
def update_trial_stats(self, trial, result):
@@ -293,10 +313,13 @@ def update_trial_stats(self, trial, result):
293313
in and make sure they're not set as pending later."""
294314

295315
assert trial in self._live_trials
296-
_, itr = self._live_trials[trial]
297-
assert itr > 0
298-
self._live_trials[trial] = (result, itr - 1)
299-
self._completed_progress += 1
316+
assert self._get_result_time(result) >= 0
317+
318+
delta = self._get_result_time(result) - \
319+
self._get_result_time(self._live_trials[trial])
320+
assert delta >= 0
321+
self._completed_progress += delta
322+
self._live_trials[trial] = result
300323

301324
def cleanup_trial(self, trial):
302325
"""Clean up statistics tracking for terminated trials (either by force
@@ -315,13 +338,19 @@ def completion_percentage(self):
315338
are dropped."""
316339
return self._completed_progress / self._total_work
317340

341+
def _get_result_time(self, result):
342+
if result is None:
343+
return 0
344+
return getattr(result, self._time_attr)
345+
318346
def _calculate_total_work(self, n, r, s):
319347
work = 0
320348
for i in range(s+1):
321349
work += int(n) * int(r)
322350
n /= self._eta
323351
n = int(np.ceil(n))
324352
r *= self._eta
353+
r = int(r)
325354
return work
326355

327356
def __repr__(self):

0 commit comments

Comments
 (0)