
Commit 0eae917

robertnishihara authored and pcmoritz committed
[rllib] Clean up evolution strategies example. (#1225)
* Remove ES observation statistics.
* Consolidate policy classes.
* Remove random stream.
* Move rollout function out of policy.
* Consolidate policy initialization.
* Replace act implementation with sess.run.
* Remove tf_utils.
* Remove variable scope.
* Remove unused imports.
* Use regular TF session.
* Use MeanStdFilter.
* Minor.
* Clarify naming.
* Update documentation.
* eps -> episodes
* Report noiseless evaluation runs.
* Clean up naming.
* Update documentation.
* Fix some bugs.
* Make it run on atari.
* Don't add action noise during evaluation runs.
* Add ES to checkpoint/restore test.
* Small cleanups and remove redundant calls to get_weights.
* Remove outdated comment.
1 parent eadb998 commit 0eae917

File tree

9 files changed: +239, -721 lines


doc/source/example-evolution-strategies.rst

Lines changed: 13 additions & 5 deletions
@@ -20,6 +20,16 @@ on the ``Humanoid-v1`` gym environment.
 
     python/ray/rllib/train.py --env=Humanoid-v1 --alg=ES
 
+To train a policy on a cluster (e.g., using 900 workers), run the following.
+
+.. code-block:: bash
+
+    python ray/python/ray/rllib/train.py \
+        --env=Humanoid-v1 \
+        --alg=ES \
+        --redis-address=<redis-address> \
+        --config='{"num_workers": 900, "episodes_per_batch": 10000, "timesteps_per_batch": 100000}'
+
 At the heart of this example, we define a ``Worker`` class. These workers have
 a method ``do_rollouts``, which will be used to simulate randomly
 perturbed policies in a given environment.
@@ -34,14 +44,12 @@ perturbed policies in a given environment.
         # Details omitted.
 
     def do_rollouts(self, params):
-        # Set the network weights.
-        self.policy.set_trainable_flat(params)
         perturbation = # Generate a random perturbation to the policy.
 
-        self.policy.set_trainable_flat(params + perturbation)
+        self.policy.set_weights(params + perturbation)
         # Do rollout with the perturbed policy.
 
-        self.policy.set_trainable_flat(params - perturbation)
+        self.policy.set_weights(params - perturbation)
         # Do rollout with the perturbed policy.
 
         # Return the rewards.
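
For readers who want to run the pattern in the hunk above end to end, here is a minimal, self-contained sketch of a Ray actor with a ``do_rollouts`` method. It is not the rllib ES implementation: the ``Worker`` class below, its toy reward function, the antithetic sampling, and all hyperparameters are illustrative assumptions.

.. code-block:: python

    import numpy as np
    import ray

    ray.init()


    @ray.remote
    class Worker(object):
        """Illustrative worker that evaluates randomly perturbed policies."""

        def __init__(self, noise_std=0.02, seed=0):
            self.noise_std = noise_std
            self.rng = np.random.RandomState(seed)
            # Stand-in "policy": a flat weight vector scored by a toy reward.
            self.weights = np.zeros(10)

        def set_weights(self, weights):
            self.weights = weights

        def _rollout(self):
            # Toy stand-in for an environment rollout: reward is higher the
            # closer the weights are to a fixed target vector.
            target = np.linspace(-1.0, 1.0, self.weights.size)
            return -float(np.sum((self.weights - target) ** 2))

        def do_rollouts(self, params):
            # Generate a random perturbation and evaluate the policy at both
            # params + perturbation and params - perturbation (antithetic pair).
            perturbation = self.noise_std * self.rng.randn(params.size)

            self.set_weights(params + perturbation)
            reward_pos = self._rollout()

            self.set_weights(params - perturbation)
            reward_neg = self._rollout()

            # Return the perturbation along with the rewards.
            return perturbation, reward_pos, reward_neg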
@@ -60,7 +68,7 @@ and use the rewards from the rollouts to update the policy.
 
 while True:
     # Get the current policy weights.
-    theta = policy.get_trainable_flat()
+    theta = policy.get_weights()
     # Put the current policy weights in the object store.
     theta_id = ray.put(theta)
     # Use the actors to do rollouts, note that we pass in the ID of the policy
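
To complement the driver loop shown in the last hunk, here is a hedged sketch of how a driver might use such workers: it puts the current weights in the object store once per iteration, passes the resulting object ID to every actor, and applies a simple ES-style update. It continues the illustrative ``Worker`` sketch above; the update rule and constants are assumptions, not taken from the commit.

.. code-block:: python

    # Continues the Worker sketch above (illustrative, not the rllib driver).
    num_workers = 4
    workers = [Worker.remote(seed=i) for i in range(num_workers)]

    theta = np.zeros(10)
    step_size = 0.1
    noise_std = 0.02

    for _ in range(300):
        # Put the current policy weights in the object store once; every
        # worker receives the same object ID rather than its own copy of theta.
        theta_id = ray.put(theta)
        results = ray.get([w.do_rollouts.remote(theta_id) for w in workers])

        # Simple ES-style update: move theta along each perturbation, weighted
        # by the difference of the antithetic returns.
        gradient = np.zeros_like(theta)
        for perturbation, reward_pos, reward_neg in results:
            gradient += perturbation * (reward_pos - reward_neg)
        gradient /= 2.0 * num_workers * noise_std ** 2
        theta = theta + step_size * gradient

    print("final weights:", theta)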
