88from ray .tune .trial import Trial
99
1010
11- def calculate_bracket_count (max_iter , eta ):
12- return int (np .log (max_iter )/ np .log (eta )) + 1
13-
14-
1511class HyperBandScheduler (FIFOScheduler ):
1612 """Implements HyperBand.
1713
14+ Blog post: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html
15+
1816 This implementation contains 3 logical levels.
1917 Each HyperBand iteration is a "band". There can be multiple
2018 bands running at once, and there can be 1 band that is incomplete.
@@ -30,33 +28,48 @@ class HyperBandScheduler(FIFOScheduler):
3028
3129 Trials added will be inserted into the most recent bracket
3230 and band and will spill over to new brackets/bands accordingly.
33- """
3431
35- def __init__ (self , max_iter = 200 , eta = 3 ):
36- """
37- args:
38- max_iter (int): maximum iterations per configuration
39- eta (int): # defines downsampling rate (default=3)
40- """
41- assert max_iter > 0 , "Max Iterations not valid!"
42- assert eta > 1 , "Downsampling rate (eta) not valid!"
32+ This maintains the bracket size and max trial count per band
33+ to 5 and 117 respectively, which correspond to that of
34+ `max_attr=81, eta=3` from the blog post. Trials will fill up
35+ from smallest bracket to largest, with largest
36+ having the most rounds of successive halving.
37+
38+ Args:
39+ time_attr (str): The TrainingResult attr to use for comparing time.
40+ Note that you can pass in something non-temporal such as
41+ `training_iteration` as a measure of progress, the only requirement
42+ is that the attribute should increase monotonically.
43+ reward_attr (str): The TrainingResult objective value attribute. As
44+ with `time_attr`, this may refer to any objective value. Stopping
45+ procedures will use this attribute.
46+ max_t (int): max time units per trial. Trials will be stopped after
47+ max_t time units (determined by time_attr) have passed.
48+ The HyperBand scheduler automatically tries to determine a
49+ reasonable number of brackets based on this and eta.
50+ """
4351
52+ def __init__ (
53+ self , time_attr = 'training_iteration' ,
54+ reward_attr = 'episode_reward_mean' , max_t = 81 ):
55+ assert max_t > 0 , "Max (time_attr) not valid!"
4456 FIFOScheduler .__init__ (self )
45- self ._eta = eta
46- self ._s_max_1 = s_max_1 = calculate_bracket_count (max_iter , eta )
47- # total number of iterations per execution of Succesive Halving (n,r)
48- B = s_max_1 * max_iter
49- # bracket trial count total
50- self ._get_n0 = lambda s : int (np .ceil (B / max_iter / (s + 1 )* eta ** s ))
57+ self ._eta = 3
58+ self ._s_max_1 = 5
59+ # bracket max trials
60+ self ._get_n0 = lambda s : int (
61+ np .ceil (self ._s_max_1 / (s + 1 ) * self ._eta ** s ))
5162 # bracket initial iterations
52- self ._get_r0 = lambda s : int (max_iter * eta ** (- s ))
63+ self ._get_r0 = lambda s : int (( max_t * self . _eta ** (- s ) ))
5364 self ._hyperbands = [[]] # list of hyperband iterations
5465 self ._trial_info = {} # Stores Trial -> Bracket, Band Iteration
5566
5667 # Tracks state for new trial add
5768 self ._state = {"bracket" : None ,
5869 "band_idx" : 0 }
5970 self ._num_stopped = 0
71+ self ._reward_attr = reward_attr
72+ self ._time_attr = time_attr
6073
6174 def on_trial_add (self , trial_runner , trial ):
6275 """On a new trial add, if current bracket is not filled,
@@ -67,22 +80,27 @@ def on_trial_add(self, trial_runner, trial):
6780 cur_bracket = self ._state ["bracket" ]
6881 cur_band = self ._hyperbands [self ._state ["band_idx" ]]
6982 if cur_bracket is None or cur_bracket .filled ():
70-
71- # if current iteration is filled, create new iteration
72- if self ._cur_band_filled ():
73- cur_band = []
74- self ._hyperbands .append (cur_band )
75- self ._state ["band_idx" ] += 1
76-
77- # cur_band will always be less than s_max_1 or else filled
78- s = len (cur_band )
79- assert s < self ._s_max_1 , "Current band is filled!"
80-
81- # create new bracket
82- cur_bracket = Bracket (self ._get_n0 (s ),
83- self ._get_r0 (s ), self ._eta , s )
84- cur_band .append (cur_bracket )
85- self ._state ["bracket" ] = cur_bracket
83+ retry = True
84+ while retry :
85+ # if current iteration is filled, create new iteration
86+ if self ._cur_band_filled ():
87+ cur_band = []
88+ self ._hyperbands .append (cur_band )
89+ self ._state ["band_idx" ] += 1
90+
91+ # cur_band will always be less than s_max_1 or else filled
92+ s = len (cur_band )
93+ assert s < self ._s_max_1 , "Current band is filled!"
94+ if self ._get_r0 (s ) == 0 :
95+ print ("Bracket too small - Retrying..." )
96+ cur_bracket = None
97+ else :
98+ retry = False
99+ cur_bracket = Bracket (
100+ self ._time_attr , self ._get_n0 (s ), self ._get_r0 (s ),
101+ self ._eta , s )
102+ cur_band .append (cur_bracket )
103+ self ._state ["bracket" ] = cur_bracket
86104
87105 self ._state ["bracket" ].add_trial (trial )
88106 self ._trial_info [trial ] = cur_bracket , self ._state ["band_idx" ]
@@ -128,9 +146,9 @@ def _process_bracket(self, trial_runner, bracket, trial):
128146 if bracket .cur_iter_done ():
129147 if bracket .finished ():
130148 self ._cleanup_bracket (trial_runner , bracket )
131- return TrialScheduler .STOP
149+ return TrialScheduler .CONTINUE
132150
133- good , bad = bracket .successive_halving ()
151+ good , bad = bracket .successive_halving (self . _reward_attr )
134152 # kill bad trials
135153 for t in bad :
136154 if t .status == Trial .PAUSED :
@@ -141,14 +159,15 @@ def _process_bracket(self, trial_runner, bracket, trial):
141159 else :
142160 raise Exception ("Trial with unexpected status encountered" )
143161
144- # ready the good trials
162+ # ready the good trials - if trial is too far ahead, don't continue
145163 for t in good :
146- if t .status == Trial .PAUSED :
147- t .unpause ()
148- elif t .status == Trial .RUNNING :
149- action = TrialScheduler .CONTINUE
150- else :
164+ if t .status not in [Trial .PAUSED , Trial .RUNNING ]:
151165 raise Exception ("Trial with unexpected status encountered" )
166+ if bracket .continue_trial (t ):
167+ if t .status == Trial .PAUSED :
168+ t .unpause ()
169+ elif t .status == Trial .RUNNING :
170+ action = TrialScheduler .CONTINUE
152171 return action
153172
154173 def _cleanup_trial (self , trial_runner , t , bracket , hard = False ):
@@ -162,11 +181,14 @@ def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
162181 bracket .cleanup_trial (t )
163182
164183 def _cleanup_bracket (self , trial_runner , bracket ):
165- """Cleans up bracket after bracket is completely finished."""
184+ """Cleans up bracket after bracket is completely finished.
185+ Lets the last trial continue to run until termination condition
186+ kicks in."""
166187 for trial in bracket .current_trials ():
167- self ._cleanup_trial (
168- trial_runner , trial , bracket ,
169- hard = (trial .status == Trial .PAUSED ))
188+ if (trial .status == Trial .PAUSED ):
189+ self ._cleanup_trial (
190+ trial_runner , trial , bracket ,
191+ hard = True )
170192
171193 def on_trial_complete (self , trial_runner , trial , result ):
172194 """Cleans up trial info from bracket if trial completed early."""
@@ -219,12 +241,15 @@ class Bracket():
219241
220242 Also keeps track of progress to ensure good scheduling.
221243 """
222- def __init__ (self , max_trials , init_iters , eta , s ):
223- self ._live_trials = {} # stores (result, itrs left before halving)
244+ def __init__ (self , time_attr , max_trials , init_t_attr , eta , s ):
245+ self ._live_trials = {} # maps trial -> current result
224246 self ._all_trials = []
247+ self ._time_attr = time_attr # attribute to
248+
225249 self ._n = self ._n0 = max_trials
226- self ._r = self ._r0 = init_iters
250+ self ._r = self ._r0 = init_t_attr
227251 self ._cumul_r = self ._r0
252+
228253 self ._eta = eta
229254 self ._halves = s
230255
@@ -237,15 +262,15 @@ def add_trial(self, trial):
237262 At a later iteration, a newly added trial will be given equal
238263 opportunity to catch up."""
239264 assert not self .filled (), "Cannot add trial to filled bracket!"
240- self ._live_trials [trial ] = ( None , self . _cumul_r )
265+ self ._live_trials [trial ] = None
241266 self ._all_trials .append (trial )
242267
243268 def cur_iter_done (self ):
244269 """Checks if all iterations have completed.
245270
246271 TODO(rliaw): also check that `t.iterations == self._r`"""
247- all_done = all (itr == 0 for _ , itr in self ._live_trials . values ())
248- return all_done
272+ return all (self . _get_result_time ( result ) >= self ._cumul_r
273+ for result in self . _live_trials . values ())
249274
250275 def finished (self ):
251276 return self ._halves == 0 and self .cur_iter_done ()
@@ -254,8 +279,8 @@ def current_trials(self):
254279 return list (self ._live_trials )
255280
256281 def continue_trial (self , trial ):
257- _ , itr = self ._live_trials [trial ]
258- if itr > 0 :
282+ result = self ._live_trials [trial ]
283+ if self . _get_result_time ( result ) < self . _cumul_r :
259284 return True
260285 else :
261286 return False
@@ -265,24 +290,19 @@ def filled(self):
265290 minimizing the need to backtrack and bookkeep previous medians"""
266291 return len (self ._live_trials ) == self ._n
267292
268- def successive_halving (self ):
293+ def successive_halving (self , reward_attr ):
269294 assert self ._halves > 0
270295 self ._halves -= 1
271296 self ._n /= self ._eta
272297 self ._n = int (np .ceil (self ._n ))
273298 self ._r *= self ._eta
274- self ._r = int (np . ceil (self ._r ))
299+ self ._r = int ((self ._r ))
275300 self ._cumul_r += self ._r
276301 sorted_trials = sorted (
277302 self ._live_trials ,
278- key = lambda t : self ._live_trials [t ][ 0 ]. episode_reward_mean )
303+ key = lambda t : getattr ( self ._live_trials [t ], reward_attr ) )
279304
280305 good , bad = sorted_trials [- self ._n :], sorted_trials [:- self ._n ]
281-
282- # reset good trials to track updated iterations
283- for t in good :
284- res , old_itr = self ._live_trials [t ]
285- self ._live_trials [t ] = (res , self ._r )
286306 return good , bad
287307
288308 def update_trial_stats (self , trial , result ):
@@ -293,10 +313,13 @@ def update_trial_stats(self, trial, result):
293313 in and make sure they're not set as pending later."""
294314
295315 assert trial in self ._live_trials
296- _ , itr = self ._live_trials [trial ]
297- assert itr > 0
298- self ._live_trials [trial ] = (result , itr - 1 )
299- self ._completed_progress += 1
316+ assert self ._get_result_time (result ) >= 0
317+
318+ delta = self ._get_result_time (result ) - \
319+ self ._get_result_time (self ._live_trials [trial ])
320+ assert delta >= 0
321+ self ._completed_progress += delta
322+ self ._live_trials [trial ] = result
300323
301324 def cleanup_trial (self , trial ):
302325 """Clean up statistics tracking for terminated trials (either by force
@@ -315,13 +338,19 @@ def completion_percentage(self):
315338 are dropped."""
316339 return self ._completed_progress / self ._total_work
317340
341+ def _get_result_time (self , result ):
342+ if result is None :
343+ return 0
344+ return getattr (result , self ._time_attr )
345+
318346 def _calculate_total_work (self , n , r , s ):
319347 work = 0
320348 for i in range (s + 1 ):
321349 work += int (n ) * int (r )
322350 n /= self ._eta
323351 n = int (np .ceil (n ))
324352 r *= self ._eta
353+ r = int (r )
325354 return work
326355
327356 def __repr__ (self ):
0 commit comments