
Commit e066bcf

elibol authored and robertnishihara committed
Synchronous parameter server example. (#1220)
* Synchronous parameter server example.
* Added sync parameter server example to documentation index.
* Consolidate documentation and minor simplifications.
* Fix linting.
1 parent 428858c commit e066bcf

File tree

4 files changed, +165 -13 lines changed


doc/source/example-parameter-server.rst

Lines changed: 69 additions & 6 deletions
@@ -1,8 +1,9 @@
 Parameter Server
 ================

-This document walks through how to implement a simple parameter server example
-using actors. To run the application, first install some dependencies.
+This document walks through how to implement simple synchronous and asynchronous
+parameter servers using actors. To run the application, first install some
+dependencies.

 .. code-block:: bash

@@ -12,17 +13,24 @@ You can view the `code for this example`_.

 .. _`code for this example`: https://github.com/ray-project/ray/tree/master/examples/parameter_server

-The example can be run as follows.
+The examples can be run as follows.

 .. code-block:: bash

-  python ray/examples/parameter_server/parameter_server.py --num-workers=4
+  # Run the asynchronous parameter server.
+  python ray/examples/parameter_server/async_parameter_server.py --num-workers=4
+
+  # Run the synchronous parameter server.
+  python ray/examples/parameter_server/sync_parameter_server.py --num-workers=4

 Note that this examples uses distributed actor handles, which are still
 considered experimental.

-The parameter server itself is implemented as an actor, which exposes the
-methods ``push`` and ``pull``.
+Asynchronous Parameter Server
+-----------------------------
+
+The asynchronous parameter server itself is implemented as an actor, which
+exposes the methods ``push`` and ``pull``.

 .. code-block:: python

@@ -62,3 +70,58 @@ Then we can create a parameter server and initiate training as follows.

   ps = ParameterServer.remote(keys, initial_values)
   worker_tasks = [worker_task.remote(ps) for _ in range(4)]
+
+Synchronous Parameter Server
+----------------------------
+
+The parameter server is implemented as an actor, which exposes the
+methods ``apply_gradients`` and ``get_weights``. A constant linear scaling
+rule is applied by scaling the learning rate by the number of workers.
+
+.. code-block:: python
+
+  @ray.remote
+  class ParameterServer(object):
+      def __init__(self, learning_rate):
+          self.net = model.SimpleCNN(learning_rate=learning_rate)
+
+      def apply_gradients(self, *gradients):
+          self.net.apply_gradients(np.mean(gradients, axis=0))
+          return self.net.variables.get_flat()
+
+      def get_weights(self):
+          return self.net.variables.get_flat()
+
+
+Workers are actors which expose the method ``compute_gradients``.
+
+.. code-block:: python
+
+  @ray.remote
+  class Worker(object):
+      def __init__(self, worker_index, batch_size=50):
+          self.worker_index = worker_index
+          self.batch_size = batch_size
+          self.mnist = input_data.read_data_sets("MNIST_data", one_hot=True,
+                                                 seed=worker_index)
+          self.net = model.SimpleCNN()
+
+      def compute_gradients(self, weights):
+          self.net.variables.set_flat(weights)
+          xs, ys = self.mnist.train.next_batch(self.batch_size)
+          return self.net.compute_gradients(xs, ys)
+
+Training alternates between computing the gradients given the current weights
+from the parameter server and updating the parameter server's weights with the
+resulting gradients.
+
+.. code-block:: python
+
+  while True:
+      gradients = [worker.compute_gradients.remote(current_weights)
+                   for worker in workers]
+      current_weights = ps.apply_gradients.remote(*gradients)
+
+Both of these examples implement the parameter server using a single actor,
+however they can be easily extended to **shard the parameters across multiple
+actors**.

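For readers following the documentation diff above: the asynchronous ``ParameterServer`` that it references is untouched by this commit, so its ``push``/``pull`` code does not appear in the diff. Below is a minimal sketch of what such an actor might look like, assuming the parameters are simply kept in a dict keyed by variable name (an assumption, not a verbatim copy of the example).

  import numpy as np
  import ray


  @ray.remote
  class ParameterServer(object):
      def __init__(self, keys, values):
          # Assumed layout: one entry per variable name. Copy the values so
          # they can be updated in place.
          values = [np.copy(value) for value in values]
          self.weights = dict(zip(keys, values))

      def push(self, keys, values):
          # Add each worker-supplied update to the corresponding parameter.
          for key, value in zip(keys, values):
              self.weights[key] += value

      def pull(self, keys):
          return [self.weights[key] for key in keys]

Workers then interact with the server only through ``ps.push.remote(...)`` and ``ps.pull.remote(...)``, as seen in the renamed async_parameter_server.py below.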
examples/parameter_server/parameter_server.py renamed to examples/parameter_server/async_parameter_server.py

Lines changed: 4 additions & 5 deletions
@@ -10,8 +10,8 @@

 import model

-parser = argparse.ArgumentParser(description="Run the parameter server "
-                                             "example.")
+parser = argparse.ArgumentParser(description="Run the asynchronous parameter "
+                                             "server example.")
 parser.add_argument("--num-workers", default=4, type=int,
                     help="The number of workers to use.")
 parser.add_argument("--redis-address", default=None, type=str,
@@ -35,10 +35,9 @@ def pull(self, keys):


 @ray.remote
-def worker_task(ps):
+def worker_task(ps, batch_size=50):
     # Download MNIST.
     mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
-    batch_size = 50

     # Initialize the model.
     net = model.SimpleCNN()
@@ -55,7 +54,7 @@ def worker_task(ps):
         ps.push.remote(keys, gradients)


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parser.parse_args()

     ray.init(redis_address=args.redis_address)

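The unchanged middle of ``worker_task`` (elided by the diff above) is the usual asynchronous pull/compute/push loop. A rough sketch of the function as it lives inside this file is shown below; the keyed ``get_weights``/``set_weights`` helpers on ``SimpleCNN`` are assumptions, since model.py is only partially shown in this commit.

  @ray.remote
  def worker_task(ps, batch_size=50):
      # Download MNIST.
      mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

      # Initialize the model.
      net = model.SimpleCNN()
      keys = net.get_weights()[0]  # assumed helper returning (keys, values)

      while True:
          # Pull the latest weights from the parameter server.
          weights = ray.get(ps.pull.remote(keys))
          net.set_weights(keys, weights)  # assumed helper

          # Compute an update on one minibatch and push it back asynchronously.
          xs, ys = mnist.train.next_batch(batch_size)
          gradients = net.compute_update(xs, ys)
          ps.push.remote(keys, gradients)

Because ``batch_size`` is now a parameter of the remote function rather than a local constant, a caller can override it per task, e.g. ``worker_task.remote(ps, 100)``.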
examples/parameter_server/model.py

Lines changed: 13 additions & 2 deletions
@@ -11,7 +11,7 @@


 class SimpleCNN(object):
-    def __init__(self):
+    def __init__(self, learning_rate=1e-4):
         with tf.Graph().as_default():

             # Create the model
@@ -29,7 +29,7 @@ def __init__(self):
             self.cross_entropy = tf.reduce_mean(cross_entropy)

             with tf.name_scope('adam_optimizer'):
-                self.optimizer = tf.train.AdamOptimizer(1e-4)
+                self.optimizer = tf.train.AdamOptimizer(learning_rate)
                 self.train_step = self.optimizer.minimize(
                     self.cross_entropy)

@@ -51,6 +51,11 @@ def __init__(self):

             self.grads = self.optimizer.compute_gradients(
                 self.cross_entropy)
+            self.grads_placeholder = [
+                (tf.placeholder("float", shape=grad[1].get_shape()), grad[1])
+                for grad in self.grads]
+            self.apply_grads_placeholder = self.optimizer.apply_gradients(
+                self.grads_placeholder)

     def compute_update(self, x, y):
         # TODO(rkn): Computing the weights before and after the training step
@@ -68,6 +73,12 @@ def compute_gradients(self, x, y):
                                        self.y_: y,
                                        self.keep_prob: 0.5})

+    def apply_gradients(self, gradients):
+        feed_dict = {}
+        for i in range(len(self.grads_placeholder)):
+            feed_dict[self.grads_placeholder[i][0]] = gradients[i]
+        self.sess.run(self.apply_grads_placeholder, feed_dict=feed_dict)
+
     def compute_accuracy(self, x, y):
         return self.sess.run(self.accuracy,
                              feed_dict={self.x: x,
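The new ``apply_gradients`` path in ``SimpleCNN`` works by pairing a placeholder with each trainable variable and feeding externally computed gradients through those placeholders. The following is a self-contained, graph-mode (TF1) sketch of the same technique; the tiny variable and loss are made up purely for illustration.

  import numpy as np
  import tensorflow as tf  # TF1-style graph API, matching the example

  # Hypothetical tiny model: one variable and a quadratic loss.
  w = tf.Variable(tf.ones([2, 1]))
  loss = tf.reduce_sum(tf.square(w))

  optimizer = tf.train.AdamOptimizer(1e-4)
  grads_and_vars = optimizer.compute_gradients(loss)

  # One placeholder per variable, shaped like that variable, paired with it.
  grads_placeholder = [(tf.placeholder("float", shape=var.get_shape()), var)
                       for _, var in grads_and_vars]
  apply_op = optimizer.apply_gradients(grads_placeholder)

  with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # The gradients could come from another process; here they are dummies.
      external_grads = [np.ones(var.get_shape().as_list(), dtype=np.float32)
                        for _, var in grads_and_vars]
      feed_dict = {ph: g for (ph, _), g in zip(grads_placeholder, external_grads)}
      sess.run(apply_op, feed_dict=feed_dict)

This is what lets the synchronous driver average the workers' gradients on the parameter server and apply the result in a single ``sess.run`` call.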
examples/parameter_server/sync_parameter_server.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+
+import numpy as np
+from tensorflow.examples.tutorials.mnist import input_data
+
+import ray
+import model
+
+parser = argparse.ArgumentParser(description="Run the synchronous parameter "
+                                             "server example.")
+parser.add_argument("--num-workers", default=4, type=int,
+                    help="The number of workers to use.")
+parser.add_argument("--redis-address", default=None, type=str,
+                    help="The Redis address of the cluster.")
+
+
+@ray.remote
+class ParameterServer(object):
+    def __init__(self, learning_rate):
+        self.net = model.SimpleCNN(learning_rate=learning_rate)
+
+    def apply_gradients(self, *gradients):
+        self.net.apply_gradients(np.mean(gradients, axis=0))
+        return self.net.variables.get_flat()
+
+    def get_weights(self):
+        return self.net.variables.get_flat()
+
+
+@ray.remote
+class Worker(object):
+    def __init__(self, worker_index, batch_size=50):
+        self.worker_index = worker_index
+        self.batch_size = batch_size
+        self.mnist = input_data.read_data_sets("MNIST_data", one_hot=True,
+                                               seed=worker_index)
+        self.net = model.SimpleCNN()
+
+    def compute_gradients(self, weights):
+        self.net.variables.set_flat(weights)
+        xs, ys = self.mnist.train.next_batch(self.batch_size)
+        return self.net.compute_gradients(xs, ys)
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    ray.init(redis_address=args.redis_address)
+
+    # Create a parameter server.
+    net = model.SimpleCNN()
+    ps = ParameterServer.remote(1e-4 * args.num_workers)
+
+    # Create workers.
+    workers = [Worker.remote(worker_index)
+               for worker_index in range(args.num_workers)]
+
+    # Download MNIST.
+    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
+
+    i = 0
+    current_weights = ps.get_weights.remote()
+    while True:
+        # Compute and apply gradients.
+        gradients = [worker.compute_gradients.remote(current_weights)
+                     for worker in workers]
+        current_weights = ps.apply_gradients.remote(*gradients)
+
+        if i % 10 == 0:
+            # Evaluate the current model.
+            net.variables.set_flat(ray.get(current_weights))
+            test_xs, test_ys = mnist.test.next_batch(1000)
+            accuracy = net.compute_accuracy(test_xs, test_ys)
+            print("Iteration {}: accuracy is {}".format(i, accuracy))
+        i += 1

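The documentation's closing note, that both examples can be extended to shard the parameters across multiple actors, can be sketched as follows. The shard-per-actor class, the plain SGD update, and the sizes are all hypothetical and only meant to show the shape of the extension.

  import numpy as np
  import ray


  @ray.remote
  class ShardedParameterServer(object):
      def __init__(self, shard):
          # Each actor owns one contiguous slice of the flat weight vector;
          # copy it so it can be updated in place.
          self.shard = shard.copy()

      def apply_gradients(self, *gradient_shards):
          # Average this shard of every worker's gradient and take a plain
          # SGD step (stand-in for the real optimizer).
          self.shard -= 1e-4 * np.mean(gradient_shards, axis=0)
          return self.shard

      def get_weights(self):
          return self.shard


  if __name__ == "__main__":
      ray.init()

      num_shards = 4
      flat_weights = np.zeros(1000, dtype=np.float32)  # hypothetical initial weights
      shards = np.array_split(flat_weights, num_shards)
      servers = [ShardedParameterServer.remote(shard) for shard in shards]

      # Pulling the full weight vector means concatenating the shards; pushing
      # means splitting each worker's flat gradient the same way.
      current_weights = np.concatenate(
          ray.get([server.get_weights.remote() for server in servers]))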