Skip to content

Commit 71f8cd2

Browse files
authored
[tune] Fixing up Hyperband (#1207)
* Fixing up Hyperband
* nit
* cleanup
* Timing test Added
* added_exception_back
* fixup_tests
* reverse placement
* fixes_and_tests
* fix
* fix
* fixlint
* cleanup_timing
* lint
* Update hyperband.py
1 parent 7c38f96 commit 71f8cd2

File tree

2 files changed

+175
-160
lines changed

2 files changed

+175
-160
lines changed

python/ray/tune/hyperband.py

Lines changed: 61 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -60,11 +60,9 @@ def __init__(self, max_iter=200, eta=3):
6060

6161
def on_trial_add(self, trial_runner, trial):
6262
"""On a new trial add, if current bracket is not filled,
63-
add to current bracket. Else, if current hp iteration is not filled,
63+
add to current bracket. Else, if current band is not filled,
6464
create new bracket, add to current bracket.
65-
Else, create new iteration, create new bracket, add to bracket.
66-
67-
TODO(rliaw): This is messy."""
65+
Else, create new iteration, create new bracket, add to bracket."""
6866

6967
cur_bracket = self._state["bracket"]
7068
cur_band = self._hyperbands[self._state["band_idx"]]
@@ -76,9 +74,9 @@ def on_trial_add(self, trial_runner, trial):
7674
self._hyperbands.append(cur_band)
7775
self._state["band_idx"] += 1
7876

79-
# cur_band will always be less than s_max or else filled
80-
s = self._s_max_1 - len(cur_band) - 1
81-
assert s >= 0, "Current band is filled but adding bracket!"
77+
# cur_band will always be less than s_max_1 or else filled
78+
s = len(cur_band)
79+
assert s < self._s_max_1, "Current band is filled!"
8280

8381
# create new bracket
8482
cur_bracket = Bracket(self._get_n0(s),
@@ -102,73 +100,87 @@ def on_trial_result(self, trial_runner, trial, result):
102100
103101
If a given trial finishes and bracket iteration is not done,
104102
the trial will be paused and resources will be given up.
105-
When bracket iteration is done, Trials will be successively halved,
106-
and during each halving phase, bad trials will be stopped while good
107-
trials will return to "PENDING". This scheduler will not start trials
108-
but will stop trials. The current running trial will not be handled,
103+
104+
This scheduler will not start trials but will stop trials.
105+
The current running trial will not be handled,
109106
as the trialrunner will be given control to handle it.
110107
111108
# TODO(rliaw) should be only called if trial has not errored"""
112109
bracket, _ = self._trial_info[trial]
113110
bracket.update_trial_stats(trial, result)
111+
114112
if bracket.continue_trial(trial):
115113
return TrialScheduler.CONTINUE
116114

117-
signal = TrialScheduler.PAUSE
115+
action = self._process_bracket(trial_runner, bracket, trial)
116+
return action
117+
118+
def _process_bracket(self, trial_runner, bracket, trial):
119+
"""This is called whenever a trial makes progress.
120+
121+
When all live trials in the bracket have no more iterations left,
122+
Trials will be successively halved. If bracket is done, all
123+
non-running trials will be stopped and cleaned up,
124+
and during each halving phase, bad trials will be stopped while good
125+
trials will return to "PENDING"."""
118126

127+
action = TrialScheduler.PAUSE
119128
if bracket.cur_iter_done():
120129
if bracket.finished():
121130
self._cleanup_bracket(trial_runner, bracket)
122131
return TrialScheduler.STOP
123-
# what if bracket is done and trial not completed?
132+
124133
good, bad = bracket.successive_halving()
125134
# kill bad trials
126135
for t in bad:
127-
self._num_stopped += 1
128136
if t.status == Trial.PAUSED:
129-
trial_runner._stop_trial(t)
130-
bracket.cleanup_trial_early(t)
131-
elif t is trial:
132-
signal = TrialScheduler.STOP
137+
self._cleanup_trial(trial_runner, t, bracket, hard=True)
138+
elif t.status == Trial.RUNNING:
139+
self._cleanup_trial(trial_runner, t, bracket, hard=False)
140+
action = TrialScheduler.STOP
133141
else:
134142
raise Exception("Trial with unexpected status encountered")
135143

136144
# ready the good trials
137145
for t in good:
138146
if t.status == Trial.PAUSED:
139147
t.unpause()
140-
elif t is trial:
141-
signal = TrialScheduler.CONTINUE
148+
elif t.status == Trial.RUNNING:
149+
action = TrialScheduler.CONTINUE
142150
else:
143151
raise Exception("Trial with unexpected status encountered")
152+
return action
144153

145-
return signal
154+
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
155+
"""Bookkeeping for trials finished. If `hard=True`, then
156+
this scheduler will force the trial_runner to release resources.
146157
147-
def _cleanup_bracket(self, trial_runner, bracket):
148-
"""Cleans up bracket after bracket is completely finished.
158+
Otherwise, only clean up trial information locally."""
159+
self._num_stopped += 1
160+
if hard:
161+
trial_runner._stop_trial(t)
162+
bracket.cleanup_trial(t)
149163

150-
Bracket information will only be cleaned up after the trialrunner has
151-
finished its bookkeeping."""
152-
for t in bracket.current_trials():
153-
if t.status == Trial.PAUSED:
154-
trial_runner._stop_trial(t)
155-
bracket.cleanup_trial_early(t)
164+
def _cleanup_bracket(self, trial_runner, bracket):
165+
"""Cleans up bracket after bracket is completely finished."""
166+
for trial in bracket.current_trials():
167+
self._cleanup_trial(
168+
trial_runner, trial, bracket,
169+
hard=(trial.status == Trial.PAUSED))
156170

157171
def on_trial_complete(self, trial_runner, trial, result):
158-
"""Cleans up trial info from bracket if trial completed early.
172+
"""Cleans up trial info from bracket if trial completed early."""
159173

160-
Bracket information will only be cleaned up after the trialrunner has
161-
finished its bookkeeping."""
162174
bracket, _ = self._trial_info[trial]
163-
bracket.cleanup_trial_early(trial)
175+
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
176+
self._process_bracket(trial_runner, bracket, trial)
164177

165178
def on_trial_error(self, trial_runner, trial):
166-
"""Cleans up trial info from bracket if trial errored early.
179+
"""Cleans up trial info from bracket if trial errored early."""
167180

168-
Bracket information will only be cleaned up after the trialrunner has
169-
finished its bookkeeping."""
170181
bracket, _ = self._trial_info[trial]
171-
bracket.cleanup_trial_early(trial)
182+
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
183+
self._process_bracket(trial_runner, bracket, trial)
172184

173185
def choose_trial_to_run(self, trial_runner, *args):
174186
"""Fair scheduling within iteration by completion percentage.
@@ -177,6 +189,7 @@ def choose_trial_to_run(self, trial_runner, *args):
177189
178190
If iteration is occupied (ie, no trials to run), then look into
179191
next iteration."""
192+
180193
for hyperband in self._hyperbands:
181194
for bracket in sorted(hyperband,
182195
key=lambda b: b.completion_percentage()):
@@ -187,10 +200,17 @@ def choose_trial_to_run(self, trial_runner, *args):
187200
return None
188201

189202
def debug_string(self):
203+
brackets = [
204+
"({0}/{1})".format(
205+
len(bracket._live_trials), len(bracket._all_trials))
206+
for band in self._hyperbands for bracket in band]
190207
return " ".join([
191208
"Using HyperBand:",
192209
"num_stopped={}".format(self._num_stopped),
193-
"brackets={}".format(sum(len(band) for band in self._hyperbands))])
210+
"total_brackets={}".format(
211+
sum(len(band) for band in self._hyperbands)),
212+
" ".join(brackets)
213+
])
194214

195215

196216
class Bracket():
@@ -278,8 +298,9 @@ def update_trial_stats(self, trial, result):
278298
self._live_trials[trial] = (result, itr - 1)
279299
self._completed_progress += 1
280300

281-
def cleanup_trial_early(self, trial):
282-
"""Clean up statistics tracking for trial that terminated early.
301+
def cleanup_trial(self, trial):
302+
"""Clean up statistics tracking for terminated trials (either by force
303+
or otherwise).
283304
284305
This may cause bad trials to continue for a long time, in the case
285306
where all the good trials finish early and there are only bad trials

0 commit comments

Comments
 (0)