
Commit 7c38f96

ericl authored and richardliaw committed
[tune] Add command line support for choosing early stopping schedulers (#1209)
* command line support
* add checkpoint freq
* fix other flags
* fix
* docs
* doc
1 parent afdc873 commit 7c38f96
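
With this change a trial scheduler can be chosen straight from the command line. A hypothetical invocation combining the flags added or touched by this commit (all values are illustrative, and the --scheduler-config JSON is assumed to map onto the scheduler's constructor arguments):

    ./train.py --env CartPole-v0 --alg PPO \
        --experiment-name my_experiment --checkpoint-freq 10 \
        --scheduler MedianStopping --scheduler-config '{"grace_period": 120.0}'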

11 files changed: +218 / -101 lines

python/ray/rllib/agent.py

Lines changed: 30 additions & 1 deletion
@@ -142,7 +142,10 @@ def train(self):
         start = time.time()
         result = self._train()
         self._iteration += 1
-        time_this_iter = time.time() - start
+        if result.time_this_iter_s is not None:
+            time_this_iter = result.time_this_iter_s
+        else:
+            time_this_iter = time.time() - start

         assert result.timesteps_this_iter is not None

@@ -340,6 +343,30 @@ def get_info(self):
         return self.info


+class _SigmoidFakeData(_MockAgent):
+    """Agent that returns sigmoid learning curves.
+
+    This can be helpful for evaluating early stopping algorithms."""
+
+    _agent_name = "SigmoidFakeData"
+    _default_config = {
+        "width": 100,
+        "height": 100,
+        "offset": 0,
+        "iter_time": 10,
+        "iter_timesteps": 1,
+    }
+
+    def _train(self):
+        i = max(0, self.iteration - self.config["offset"])
+        v = np.tanh(float(i) / self.config["width"])
+        v *= self.config["height"]
+        return TrainingResult(
+            episode_reward_mean=v, episode_len_mean=v,
+            timesteps_this_iter=self.config["iter_timesteps"],
+            time_this_iter_s=self.config["iter_time"], info={})
+
+
 def get_agent_class(alg):
     """Returns the class of an known agent given its name."""

@@ -360,6 +387,8 @@ def get_agent_class(alg):
         return script_runner.ScriptRunner
     elif alg == "__fake":
         return _MockAgent
+    elif alg == "__sigmoid_fake_data":
+        return _SigmoidFakeData
     else:
         raise Exception(
             ("Unknown algorithm {}, check --alg argument. Valid choices " +

python/ray/rllib/train.py

Lines changed: 12 additions & 3 deletions
@@ -9,7 +9,7 @@
 import yaml

 from ray.tune.config_parser import make_parser, resources_to_json
-from ray.tune.tune import run_experiments
+from ray.tune.tune import make_scheduler, run_experiments


 EXAMPLE_USAGE = """
@@ -18,6 +18,8 @@

 Grid search example:
     ./train.py -f tuned_examples/cartpole-grid-search-example.yaml
+
+Note that -f overrides all other trial-specific command-line options.
 """


@@ -33,6 +35,8 @@
                     help="Number of CPUs to allocate to Ray.")
 parser.add_argument("--num-gpus", default=None, type=int,
                     help="Number of GPUs to allocate to Ray.")
+parser.add_argument("--experiment-name", default="default", type=str,
+                    help="Name of experiment dir.")
 parser.add_argument("-f", "--config-file", default=None, type=str,
                     help="If specified, use config options from this file.")

@@ -43,15 +47,19 @@
         with open(args.config_file) as f:
             experiments = yaml.load(f)
     else:
+        # Note: keep this in sync with tune/config_parser.py
         experiments = {
-            "default": {  # i.e. log to /tmp/ray/default
+            args.experiment_name: {  # i.e. log to /tmp/ray/default
                 "alg": args.alg,
+                "checkpoint_freq": args.checkpoint_freq,
+                "local_dir": args.local_dir,
                 "env": args.env,
                 "resources": resources_to_json(args.resources),
                 "stop": args.stop,
                 "config": args.config,
                 "restore": args.restore,
                 "repeat": args.repeat,
+                "upload_dir": args.upload_dir,
             }
         }

@@ -62,5 +70,6 @@
         parser.error("the following arguments are required: --env")

     run_experiments(
-        experiments, redis_address=args.redis_address,
+        experiments, scheduler=make_scheduler(args),
+        redis_address=args.redis_address,
         num_cpus=args.num_cpus, num_gpus=args.num_gpus)
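
The command-line path above builds an experiment spec and hands it to run_experiments along with the scheduler. Roughly the equivalent programmatic form, as a sketch (the spec keys are the ones shown in the diff; the values and the MedianStoppingRule import path are assumptions):

from ray.tune.median_stopping_rule import MedianStoppingRule  # path assumed
from ray.tune.tune import run_experiments

# Mirrors the dict train.py assembles from its flags; values illustrative.
experiments = {
    "my_experiment": {  # i.e. log to /tmp/ray/my_experiment
        "alg": "PPO",
        "checkpoint_freq": 10,
        "local_dir": "/tmp/ray",
        "env": "CartPole-v0",
        "resources": {"cpu": 1, "gpu": 0},
        "stop": {"training_iteration": 100},
        "config": {},
        "restore": None,
        "repeat": 1,
        "upload_dir": "",
    }
}

run_experiments(experiments, scheduler=MedianStoppingRule())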

python/ray/tune/config_parser.py

Lines changed: 7 additions & 2 deletions
@@ -31,6 +31,7 @@ def make_parser(**kwargs):

     parser = argparse.ArgumentParser(**kwargs)

+    # Note: keep this in sync with rllib/train.py
     parser.add_argument("--alg", default=None, type=str,
                         help="The learning algorithm to train.")
     parser.add_argument("--stop", default="{}", type=json.loads,
@@ -44,10 +45,14 @@ def make_parser(**kwargs):
                         help="Number of times to repeat each trial.")
     parser.add_argument("--local-dir", default="/tmp/ray", type=str,
                         help="Local dir to save training results to.")
-    parser.add_argument("--upload-dir", default=None, type=str,
+    parser.add_argument("--upload-dir", default="", type=str,
                         help="URI to upload training results to.")
-    parser.add_argument("--checkpoint-freq", default=None, type=int,
+    parser.add_argument("--checkpoint-freq", default=0, type=int,
                         help="How many iterations between checkpoints.")
+    parser.add_argument("--scheduler", default="FIFO", type=str,
+                        help="FIFO, MedianStopping, or HyperBand")
+    parser.add_argument("--scheduler-config", default="{}", type=json.loads,
+                        help="Config options to pass to the scheduler.")

     # Note: this currently only makes sense when running a single trial
     parser.add_argument("--restore", default=None, type=str,

python/ray/tune/hyperband.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ class HyperBandScheduler(FIFOScheduler):
     and band and will spill over to new brackets/bands accordingly.
     """

-    def __init__(self, max_iter, eta=3):
+    def __init__(self, max_iter=200, eta=3):
         """
         args:
             max_iter (int): maximum iterations per configuration
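
Giving max_iter a default makes the scheduler constructible with no arguments, presumably so that --scheduler HyperBand works with the default empty --scheduler-config. For example:

from ray.tune.hyperband import HyperBandScheduler

scheduler = HyperBandScheduler()             # now valid: max_iter=200, eta=3
scheduler = HyperBandScheduler(max_iter=81)  # still configurable as before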
Lines changed: 100 additions & 0 deletions (new file)
@@ -0,0 +1,100 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+
+from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
+
+
+class MedianStoppingRule(FIFOScheduler):
+    """Implements the median stopping rule as described in the Vizier paper:
+
+    https://research.google.com/pubs/pub46180.html
+
+    Args:
+        time_attr (str): The TrainingResult attr to use for comparing time.
+            Note that you can pass in something non-temporal such as
+            `training_iteration` as a measure of progress, the only requirement
+            is that the attribute should increase monotonically.
+        reward_attr (str): The TrainingResult objective value attribute. As
+            with `time_attr`, this may refer to any objective value that
+            is supposed to increase with time.
+        grace_period (float): Only stop trials at least this old in time.
+            The units are the same as the attribute named by `time_attr`.
+        min_samples_required (int): Min samples to compute median over.
+        hard_stop (bool): If false, pauses trials instead of stopping
+            them. When all other trials are complete, paused trials will be
+            resumed and allowed to run FIFO.
+    """
+
+    def __init__(
+            self, time_attr='time_total_s', reward_attr='episode_reward_mean',
+            grace_period=60.0, min_samples_required=3, hard_stop=True):
+        FIFOScheduler.__init__(self)
+        self._stopped_trials = set()
+        self._completed_trials = set()
+        self._results = collections.defaultdict(list)
+        self._grace_period = grace_period
+        self._min_samples_required = min_samples_required
+        self._reward_attr = reward_attr
+        self._time_attr = time_attr
+        self._hard_stop = hard_stop
+
+    def on_trial_result(self, trial_runner, trial, result):
+        """Callback for early stopping.
+
+        This stopping rule stops a running trial if the trial's best objective
+        value by step `t` is strictly worse than the median of the running
+        averages of all completed trials' objectives reported up to step `t`.
+        """
+
+        if trial in self._stopped_trials:
+            assert not self._hard_stop
+            return TrialScheduler.CONTINUE  # fall back to FIFO
+
+        time = getattr(result, self._time_attr)
+        self._results[trial].append(result)
+        median_result = self._get_median_result(time)
+        best_result = self._best_result(trial)
+        print("Trial {} best res={} vs median res={} at t={}".format(
+            trial, best_result, median_result, time))
+        if best_result < median_result and time > self._grace_period:
+            print("MedianStoppingRule: early stopping {}".format(trial))
+            self._stopped_trials.add(trial)
+            if self._hard_stop:
+                return TrialScheduler.STOP
+            else:
+                return TrialScheduler.PAUSE
+        else:
+            return TrialScheduler.CONTINUE
+
+    def on_trial_complete(self, trial_runner, trial, result):
+        self._results[trial].append(result)
+        self._completed_trials.add(trial)
+
+    def debug_string(self):
+        return "Using MedianStoppingRule: num_stopped={}.".format(
+            len(self._stopped_trials))
+
+    def _get_median_result(self, time):
+        scores = []
+        for trial in self._completed_trials:
+            scores.append(self._running_result(trial, time))
+        if len(scores) >= self._min_samples_required:
+            return np.median(scores)
+        else:
+            return float('-inf')
+
+    def _running_result(self, trial, t_max=float('inf')):
+        results = self._results[trial]
+        # TODO(ekl) we could do interpolation to be more precise, but for now
+        # assume len(results) is large and the time diffs are roughly equal
+        return np.mean(
+            [getattr(r, self._reward_attr)
+             for r in results if getattr(r, self._time_attr) <= t_max])
+
+    def _best_result(self, trial):
+        results = self._results[trial]
+        return max([getattr(r, self._reward_attr) for r in results])
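
To make the stopping condition concrete, here is a small self-contained exercise of the class (the new module's import path is assumed, since it is not shown above; plain strings stand in for Trial objects, which only need to be hashable here):

import collections

from ray.tune.median_stopping_rule import MedianStoppingRule  # path assumed
from ray.tune.trial_scheduler import TrialScheduler

Result = collections.namedtuple(
    "Result", ["time_total_s", "episode_reward_mean"])

rule = MedianStoppingRule(grace_period=0.0, min_samples_required=1)

# Two completed trials whose running averages at t=100 are 10 and 20.
rule.on_trial_complete(None, "trial_a", Result(100.0, 10.0))
rule.on_trial_complete(None, "trial_b", Result(100.0, 20.0))

# A running trial whose best reward (5.0) is below the median (15.0) of
# the completed trials is stopped once past the grace period.
decision = rule.on_trial_result(None, "trial_c", Result(100.0, 5.0))
assert decision == TrialScheduler.STOP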

python/ray/tune/result.py

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,8 @@
     # (Auto-filled) Number of timesteps in the simulator in this iteration.
     "timesteps_this_iter",

-    # (Auto-filled) Time in seconds this iteration took to run.
+    # (Auto-filled) Time in seconds this iteration took to run. This may be
+    # overriden in order to override the system-computed time difference.
     "time_this_iter_s",

     # (Auto-filled) Accumulated time in seconds for this entire experiment.

python/ray/tune/trial.py

Lines changed: 2 additions & 2 deletions
@@ -60,7 +60,7 @@ class Trial(object):
     def __init__(
             self, env_creator, alg, config={}, local_dir='/tmp/ray',
             experiment_tag=None, resources=Resources(cpu=1, gpu=0),
-            stopping_criterion={}, checkpoint_freq=None,
+            stopping_criterion={}, checkpoint_freq=0,
             restore_path=None, upload_dir=None):
         """Initialize a new trial.

@@ -179,7 +179,7 @@ def should_stop(self, result):
     def should_checkpoint(self):
         """Whether this trial is due for checkpointing."""

-        if self.checkpoint_freq is None:
+        if not self.checkpoint_freq:
            return False

        return self.last_result.training_iteration % self.checkpoint_freq == 0
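
Switching the default from None to 0 keeps checkpointing disabled by default, since "if not self.checkpoint_freq" is true for both values. A standalone sketch of the cadence logic:

def should_checkpoint(checkpoint_freq, training_iteration):
    # Mirrors Trial.should_checkpoint: 0 (the new default) disables
    # checkpointing; otherwise checkpoint every checkpoint_freq iterations.
    if not checkpoint_freq:
        return False
    return training_iteration % checkpoint_freq == 0

print([i for i in range(1, 31) if should_checkpoint(10, i)])  # [10, 20, 30]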

python/ray/tune/trial_runner.py

Lines changed: 14 additions & 0 deletions
@@ -2,6 +2,7 @@
 from __future__ import division
 from __future__ import print_function

+import os
 import ray
 import time
 import traceback
@@ -42,9 +43,21 @@ def __init__(self, scheduler=None):
         self._committed_resources = Resources(cpu=0, gpu=0)
         self._resources_initialized = False

+        # For debugging, it may be useful to halt trials after some time has
+        # elapsed. TODO(ekl) consider exposing this in the API.
+        self._global_time_limit = float(
+            os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf')))
+        self._total_time = 0
+
     def is_finished(self):
         """Returns whether all trials have finished running."""

+        if self._total_time > self._global_time_limit:
+            print(
+                "Exceeded global time limit {} / {}".format(
+                    self._total_time, self._global_time_limit))
+            return True
+
         for t in self._trials:
             if t.status in [Trial.PENDING, Trial.RUNNING, Trial.PAUSED]:
                 return False
@@ -148,6 +161,7 @@ def _process_events(self):
             result = ray.get(result_id)
             print("result", result)
             trial.last_result = result
+            self._total_time += result.time_this_iter_s

             if trial.should_stop(result):
                 self._scheduler_alg.on_trial_complete(self, trial, result)
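
Because _total_time accumulates the per-iteration times reported in results, the limit is expressed in seconds of training time. For debugging it can be set at launch, e.g. (value illustrative):

    TRIALRUNNER_WALLTIME_LIMIT=600 ./train.py -f tuned_examples/cartpole-grid-search-example.yaml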
