@@ -60,11 +60,9 @@ def __init__(self, max_iter=200, eta=3):
6060
6161 def on_trial_add (self , trial_runner , trial ):
6262 """On a new trial add, if current bracket is not filled,
63- add to current bracket. Else, if current hp iteration is not filled,
63+ add to current bracket. Else, if current band is not filled,
6464 create new bracket, add to current bracket.
65- Else, create new iteration, create new bracket, add to bracket.
66-
67- TODO(rliaw): This is messy."""
65+ Else, create new iteration, create new bracket, add to bracket."""
6866
6967 cur_bracket = self ._state ["bracket" ]
7068 cur_band = self ._hyperbands [self ._state ["band_idx" ]]
@@ -76,9 +74,9 @@ def on_trial_add(self, trial_runner, trial):
7674 self ._hyperbands .append (cur_band )
7775 self ._state ["band_idx" ] += 1
7876
79- # cur_band will always be less than s_max or else filled
80- s = self . _s_max_1 - len (cur_band ) - 1
81- assert s >= 0 , "Current band is filled but adding bracket !"
77+ # cur_band will always be less than s_max_1 or else filled
78+ s = len (cur_band )
79+ assert s < self . _s_max_1 , "Current band is filled!"
8280
8381 # create new bracket
8482 cur_bracket = Bracket (self ._get_n0 (s ),
@@ -102,73 +100,87 @@ def on_trial_result(self, trial_runner, trial, result):
102100
103101 If a given trial finishes and bracket iteration is not done,
104102 the trial will be paused and resources will be given up.
105- When bracket iteration is done, Trials will be successively halved,
106- and during each halving phase, bad trials will be stopped while good
107- trials will return to "PENDING". This scheduler will not start trials
108- but will stop trials. The current running trial will not be handled,
103+
104+ This scheduler will not start trials but will stop trials.
105+ The current running trial will not be handled,
109106 as the trialrunner will be given control to handle it.
110107
111108 # TODO(rliaw) should be only called if trial has not errored"""
112109 bracket , _ = self ._trial_info [trial ]
113110 bracket .update_trial_stats (trial , result )
111+
114112 if bracket .continue_trial (trial ):
115113 return TrialScheduler .CONTINUE
116114
117- signal = TrialScheduler .PAUSE
115+ action = self ._process_bracket (trial_runner , bracket , trial )
116+ return action
117+
def _process_bracket(self, trial_runner, bracket, trial):
    """Decide the scheduler action for `bracket` and halve it if ready.

    Returns `TrialScheduler.PAUSE` when `bracket.cur_iter_done()` is false.
    When the iteration is done and `bracket.finished()` is true, the whole
    bracket is cleaned up and `TrialScheduler.STOP` is returned.

    Otherwise `bracket.successive_halving()` splits the trials into a
    good and a bad group:

    * bad trials are cleaned up -- PAUSED ones with `hard=True`, RUNNING
      ones with `hard=False`, in which case the action becomes STOP;
    * good PAUSED trials are unpaused, and a good RUNNING trial turns
      the action into CONTINUE.

    `trial` is accepted but not read by this method.

    Raises:
        Exception: if a trial in either group is neither PAUSED nor
            RUNNING.
    """
    action = TrialScheduler.PAUSE

    # Nothing to decide until the bracket's current iteration is done.
    if not bracket.cur_iter_done():
        return action

    if bracket.finished():
        self._cleanup_bracket(trial_runner, bracket)
        return TrialScheduler.STOP

    winners, losers = bracket.successive_halving()

    # Retire the losing trials first.
    for loser in losers:
        state = loser.status
        if state == Trial.PAUSED:
            self._cleanup_trial(trial_runner, loser, bracket, hard=True)
            continue
        if state != Trial.RUNNING:
            raise Exception("Trial with unexpected status encountered")
        # A running loser is only cleaned up locally (hard=False); the
        # STOP action is returned to the caller instead.
        self._cleanup_trial(trial_runner, loser, bracket, hard=False)
        action = TrialScheduler.STOP

    # Then make the surviving trials runnable again.
    for winner in winners:
        state = winner.status
        if state == Trial.PAUSED:
            winner.unpause()
            continue
        if state != Trial.RUNNING:
            raise Exception("Trial with unexpected status encountered")
        action = TrialScheduler.CONTINUE

    return action
144153
145- return signal
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
    """Record that trial `t` is finished and drop it from `bracket`.

    Always increments `self._num_stopped` and calls
    `bracket.cleanup_trial(t)`. When `hard` is true, `trial_runner` is
    first asked to stop the trial via `_stop_trial(t)`; otherwise only
    the local bookkeeping is performed.
    """
    self._num_stopped += 1

    # Forced stop goes through the trial runner before the bracket forgets
    # about the trial.
    if hard:
        trial_runner._stop_trial(t)

    bracket.cleanup_trial(t)
149163
150- Bracket information will only be cleaned up after the trialrunner has
151- finished its bookkeeping ."""
152- for t in bracket .current_trials ():
153- if t . status == Trial . PAUSED :
154- trial_runner . _stop_trial ( t )
155- bracket . cleanup_trial_early ( t )
def _cleanup_bracket(self, trial_runner, bracket):
    """Clean up every trial still listed by a completely finished bracket.

    Each trial returned by `bracket.current_trials()` is passed to
    `_cleanup_trial`; trials whose status is PAUSED are cleaned up with
    `hard=True`, all others with `hard=False`.
    """
    for member in bracket.current_trials():
        is_paused = member.status == Trial.PAUSED
        self._cleanup_trial(trial_runner, member, bracket, hard=is_paused)
156170
def on_trial_complete(self, trial_runner, trial, result):
    """Handle a trial that completed: clean it up, then process its bracket.

    The trial's bracket is looked up in `self._trial_info`. The trial is
    cleaned up locally (`hard=False`) and the bracket is then passed to
    `_process_bracket`, whose return value is discarded. `result` is not
    read here.
    """
    owning_bracket, _ = self._trial_info[trial]

    self._cleanup_trial(trial_runner, trial, owning_bracket, hard=False)
    self._process_bracket(trial_runner, owning_bracket, trial)
164177
def on_trial_error(self, trial_runner, trial):
    """Handle a trial that errored: clean it up, then process its bracket.

    The trial's bracket is looked up in `self._trial_info`. The trial is
    cleaned up locally (`hard=False`) and the bracket is then passed to
    `_process_bracket`, whose return value is discarded.
    """
    owning_bracket, _ = self._trial_info[trial]

    self._cleanup_trial(trial_runner, trial, owning_bracket, hard=False)
    self._process_bracket(trial_runner, owning_bracket, trial)
172184
173185 def choose_trial_to_run (self , trial_runner , * args ):
174186 """Fair scheduling within iteration by completion percentage.
@@ -177,6 +189,7 @@ def choose_trial_to_run(self, trial_runner, *args):
177189
178190 If iteration is occupied (ie, no trials to run), then look into
179191 next iteration."""
192+
180193 for hyperband in self ._hyperbands :
181194 for bracket in sorted (hyperband ,
182195 key = lambda b : b .completion_percentage ()):
@@ -187,10 +200,17 @@ def choose_trial_to_run(self, trial_runner, *args):
187200 return None
188201
def debug_string(self):
    """Return a one-line, space-separated status summary.

    The summary contains a fixed header, the `_num_stopped` counter, the
    total number of brackets across all bands in `self._hyperbands`, and
    one "(live/all)" entry per bracket built from the lengths of its
    `_live_trials` and `_all_trials`.
    """
    bracket_total = 0
    occupancy = []
    for band in self._hyperbands:
        bracket_total += len(band)
        for bracket in band:
            occupancy.append(
                "({0}/{1})".format(
                    len(bracket._live_trials), len(bracket._all_trials)))

    pieces = ["Using HyperBand:"]
    pieces.append("num_stopped={}".format(self._num_stopped))
    pieces.append("total_brackets={}".format(bracket_total))
    # The per-bracket entries are joined into a single final field, so an
    # empty scheduler still yields a trailing separator.
    pieces.append(" ".join(occupancy))
    return " ".join(pieces)
194214
195215
196216class Bracket ():
@@ -278,8 +298,9 @@ def update_trial_stats(self, trial, result):
278298 self ._live_trials [trial ] = (result , itr - 1 )
279299 self ._completed_progress += 1
280300
281- def cleanup_trial_early (self , trial ):
282- """Clean up statistics tracking for trial that terminated early.
301+ def cleanup_trial (self , trial ):
302+ """Clean up statistics tracking for terminated trials (either by force
303+ or otherwise).
283304
284305 This may cause bad trials to continue for a long time, in the case
285306 where all the good trials finish early and there are only bad trials
0 commit comments