
Commit 67c52d5

release save/load CKPT and A3C Continuous Action Space Example
1 parent ccc4e44 commit 67c52d5

3 files changed: +436 additions, -67 deletions

docs/modules/files.rst

Lines changed: 28 additions & 13 deletions
@@ -44,11 +44,14 @@ sake of cross-platform.
   load_flickr1M_dataset

   save_npz
-  save_npz_dict
   load_npz
-  load_npz_dict
   assign_params
-  load_and_assign_npz
+  load_and_assign_npz
+  save_npz_dict
+  load_npz_dict
+  save_ckpt
+  load_ckpt
+

   save_any_to_npy
   load_npy_to_any
@@ -114,25 +117,37 @@ Save network into list (npz)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autofunction:: save_npz

+Load network from list (npz)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autofunction:: load_npz
+
+Assign a list of parameters to network
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autofunction:: assign_params
+
+Load and assign a list of parameters to network
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autofunction:: load_and_assign_npz
+
+
 Save network into dict (npz)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autofunction:: save_npz_dict

-Load network from save_npz
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-.. autofunction:: load_npz
-
-Load network from save_npz_dict
+Load network from dict (npz)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autofunction:: load_npz_dict

-Assign parameters to network
+
+Save network into ckpt
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-.. autofunction:: assign_params
+.. autofunction:: save_ckpt
+
+Load network from ckpt
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autofunction:: load_ckpt
+

-Load and assign parameters to network
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-.. autofunction:: load_and_assign_npz

 Load and save variables
 ------------------------
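For quick reference, here is a minimal usage sketch of the two new helpers documented above. It is not part of the commit; it only uses the keyword arguments that the A3C example in this commit actually passes to tl.files.save_ckpt and tl.files.load_ckpt (sess, mode_name, var_list, save_dir, is_latest, printable), and the placeholder shape, layer sizes, and 'checkpoint' directory are made up for illustration.

import tensorflow as tf
import tensorlayer as tl

# a tiny throwaway network whose parameters we want to save and restore
x = tf.placeholder(tf.float32, [None, 784], name='x')
net = tl.layers.InputLayer(x, name='input')
net = tl.layers.DenseLayer(net, n_units=100, act=tf.nn.relu, name='dense1')

sess = tf.Session()
tl.layers.initialize_global_variables(sess)

# write the parameters to ./checkpoint/model.ckpt
tl.files.save_ckpt(sess=sess, mode_name='model.ckpt',
                   var_list=net.all_params, save_dir='checkpoint', printable=True)

# restore the latest checkpoint from the same directory
tl.files.load_ckpt(sess=sess, var_list=net.all_params,
                   save_dir='checkpoint', printable=True)
# or pin a specific file instead of the latest one:
# tl.files.load_ckpt(sess=sess, mode_name='model.ckpt', var_list=net.all_params,
#                    save_dir='checkpoint', is_latest=False, printable=True)

This mirrors how ACNet.save_ckpt and ACNet.load_ckpt wrap these calls in the example below.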
Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
"""
Asynchronous Advantage Actor Critic (A3C) with Continuous Action Space.

Actor Critic History
----------------------
A3C > DDPG (for continuous action space) > AC

Advantage
----------
Trains faster and is more stable than AC.

Disadvantage
-------------
Has bias.

Reference
----------
MorvanZhou's tutorial: https://morvanzhou.github.io/tutorials/
MorvanZhou's code: https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/experiments/Solve_BipedalWalker/A3C.py

Environment
-----------
BipedalWalker-v2 : https://gym.openai.com/envs/BipedalWalker-v2

Reward is given for moving forward, totalling 300+ points up to the far end.
If the robot falls, it gets -100. Applying motor torque costs a small amount of
points, so a more optimal agent will get a better score. The state consists of hull
angle speed, angular velocity, horizontal speed, vertical speed, position of joints
and joint angular speeds, legs contact with ground, and 10 lidar rangefinder
measurements. There are no coordinates in the state vector.
"""

import multiprocessing, threading, gym, os, shutil
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import numpy as np

GAME = 'BipedalWalker-v2'   # BipedalWalkerHardcore-v2
OUTPUT_GRAPH = False
LOG_DIR = './log'
N_WORKERS = multiprocessing.cpu_count()
# N_WORKERS = 4
MAX_GLOBAL_EP = 20000  # 8000
GLOBAL_NET_SCOPE = 'Global_Net'
UPDATE_GLOBAL_ITER = 10
GAMMA = 0.999
ENTROPY_BETA = 0.005
LR_A = 0.00002    # learning rate for actor
LR_C = 0.0001     # learning rate for critic
GLOBAL_RUNNING_R = []
GLOBAL_EP = 0     # increases during training; training stops when it >= MAX_GLOBAL_EP

env = gym.make(GAME)

N_S = env.observation_space.shape[0]
N_A = env.action_space.shape[0]
A_BOUND = [env.action_space.low, env.action_space.high]
# print(env.unwrapped.hull.position[0])
# exit()

class ACNet(object):
    def __init__(self, scope, globalAC=None):
        self.scope = scope
        if scope == GLOBAL_NET_SCOPE:
            ## the global network only does inference
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self._build_net()
                self.a_params = tl.layers.get_variables_with_name(scope + '/actor', True, False)
                self.c_params = tl.layers.get_variables_with_name(scope + '/critic', True, False)

                normal_dist = tf.contrib.distributions.Normal(self.mu, self.sigma)  # for continuous action space

                with tf.name_scope('choose_a'):  # use local params to choose action
                    self.A = tf.clip_by_value(tf.squeeze(normal_dist.sample(1), axis=0), *A_BOUND)
        else:
            ## worker networks compute gradients locally and apply them to the global network
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_his = tf.placeholder(tf.float32, [None, N_A], 'A')
                self.v_target = tf.placeholder(tf.float32, [None, 1], 'Vtarget')

                self._build_net()

                td = tf.subtract(self.v_target, self.v, name='TD_error')
                with tf.name_scope('c_loss'):
                    self.c_loss = tf.reduce_mean(tf.square(td))

                with tf.name_scope('wrap_a_out'):
                    self.test = self.sigma[0]
                    self.mu, self.sigma = self.mu * A_BOUND[1], self.sigma + 1e-5

                normal_dist = tf.contrib.distributions.Normal(self.mu, self.sigma)  # for continuous action space

                with tf.name_scope('a_loss'):
                    log_prob = normal_dist.log_prob(self.a_his)
                    exp_v = log_prob * td
                    entropy = normal_dist.entropy()  # encourage exploration
                    self.exp_v = ENTROPY_BETA * entropy + exp_v
                    self.a_loss = tf.reduce_mean(-self.exp_v)

                with tf.name_scope('choose_a'):  # use local params to choose action
                    self.A = tf.clip_by_value(tf.squeeze(normal_dist.sample(1), axis=0), *A_BOUND)

                with tf.name_scope('local_grad'):
                    self.a_params = tl.layers.get_variables_with_name(scope + '/actor', True, False)
                    self.c_params = tl.layers.get_variables_with_name(scope + '/critic', True, False)
                    self.a_grads = tf.gradients(self.a_loss, self.a_params)
                    self.c_grads = tf.gradients(self.c_loss, self.c_params)

            with tf.name_scope('sync'):
                with tf.name_scope('pull'):
                    self.pull_a_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.a_params, globalAC.a_params)]
                    self.pull_c_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.c_params, globalAC.c_params)]
                with tf.name_scope('push'):
                    self.update_a_op = OPT_A.apply_gradients(zip(self.a_grads, globalAC.a_params))
                    self.update_c_op = OPT_C.apply_gradients(zip(self.c_grads, globalAC.c_params))

    def _build_net(self):
        w_init = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope('actor'):
            nn = InputLayer(self.s, name='in')
            nn = DenseLayer(nn, n_units=500, act=tf.nn.relu6, W_init=w_init, name='la')
            nn = DenseLayer(nn, n_units=300, act=tf.nn.relu6, W_init=w_init, name='la2')
            mu = DenseLayer(nn, n_units=N_A, act=tf.nn.tanh, W_init=w_init, name='mu')
            sigma = DenseLayer(nn, n_units=N_A, act=tf.nn.softplus, W_init=w_init, name='sigma')
            self.mu = mu.outputs
            self.sigma = sigma.outputs

        with tf.variable_scope('critic'):
            nn = InputLayer(self.s, name='in')
            nn = DenseLayer(nn, n_units=500, act=tf.nn.relu6, W_init=w_init, name='lc')
            nn = DenseLayer(nn, n_units=200, act=tf.nn.relu6, W_init=w_init, name='lc2')
            v = DenseLayer(nn, n_units=1, W_init=w_init, name='v')
            self.v = v.outputs

    def update_global(self, feed_dict):  # run by a local worker
        _, _, t = sess.run([self.update_a_op, self.update_c_op, self.test], feed_dict)  # local grads applied to global net
        return t

    def pull_global(self):  # run by a local worker
        sess.run([self.pull_a_params_op, self.pull_c_params_op])

    def choose_action(self, s):  # run by a local worker
        s = s[np.newaxis, :]
        return sess.run(self.A, {self.s: s})[0]

    def save_ckpt(self):
        tl.files.save_ckpt(sess=sess, mode_name='model.ckpt', var_list=self.a_params+self.c_params, save_dir=self.scope, printable=True)

    def load_ckpt(self):
        tl.files.load_ckpt(sess=sess, var_list=self.a_params+self.c_params, save_dir=self.scope, printable=True)
        # tl.files.load_ckpt(sess=sess, mode_name='model.ckpt', var_list=self.a_params+self.c_params, save_dir=self.scope, is_latest=False, printable=True)

class Worker(object):
    def __init__(self, name, globalAC):
        self.env = gym.make(GAME)
        self.name = name
        self.AC = ACNet(name, globalAC)

    def work(self):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()
            ep_r = 0
            while True:
                ## visualize Worker_0 during training
                if self.name == 'Worker_0' and total_step % 30 == 0:
                    self.env.render()
                a = self.AC.choose_action(s)
                s_, r, done, info = self.env.step(a)

                ## scale the falling penalty from -100 down to -2
                if r == -100: r = -2

                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r)

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:   # update global and assign to local net
                    if done:
                        v_s_ = 0   # terminal
                    else:
                        v_s_ = sess.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[0, 0]
                    buffer_v_target = []
                    for r in buffer_r[::-1]:    # reverse buffer r
                        v_s_ = r + GAMMA * v_s_
                        buffer_v_target.append(v_s_)
                    buffer_v_target.reverse()

                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(buffer_v_target)
                    feed_dict = {
                        self.AC.s: buffer_s,
                        self.AC.a_his: buffer_a,
                        self.AC.v_target: buffer_v_target,
                    }
                    ## update gradients on global network
                    test = self.AC.update_global(feed_dict)
                    buffer_s, buffer_a, buffer_r = [], [], []

                    ## update local network from global network
                    self.AC.pull_global()

                s = s_
                total_step += 1
                if done:
                    if len(GLOBAL_RUNNING_R) == 0:  # record running episode reward
                        GLOBAL_RUNNING_R.append(ep_r)
                    else:
                        GLOBAL_RUNNING_R.append(0.95 * GLOBAL_RUNNING_R[-1] + 0.05 * ep_r)
                    print(
                        self.name,
                        "episode:", GLOBAL_EP,
                        "| pos: %i" % self.env.unwrapped.hull.position[0],  # horizontal position of the hull
                        '| reward: %.1f' % ep_r,
                        "| running_reward: %.1f" % GLOBAL_RUNNING_R[-1],
                        # '| sigma:', test,  # debug
                        'WIN '*5 if self.env.unwrapped.hull.position[0] >= 88 else '',
                    )
                    GLOBAL_EP += 1
                    break

if __name__ == "__main__":
    sess = tf.Session()

    ###============================= TRAINING ===============================###
    with tf.device("/cpu:0"):
        OPT_A = tf.train.RMSPropOptimizer(LR_A, name='RMSPropA')
        OPT_C = tf.train.RMSPropOptimizer(LR_C, name='RMSPropC')
        GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)   # we only need its params
        workers = []
        # Create workers
        for i in range(N_WORKERS):
            i_name = 'Worker_%i' % i   # worker name
            workers.append(Worker(i_name, GLOBAL_AC))

    COORD = tf.train.Coordinator()
    tl.layers.initialize_global_variables(sess)

    ## start TF threading
    worker_threads = []
    for worker in workers:
        t = threading.Thread(target=worker.work)  # bind this worker directly (avoids late binding of the loop variable)
        t.start()
        worker_threads.append(t)
    COORD.join(worker_threads)

    GLOBAL_AC.save_ckpt()

    ###============================= EVALUATION =============================###
    env = gym.make(GAME)
    GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)
    tl.layers.initialize_global_variables(sess)
    GLOBAL_AC.load_ckpt()
    while True:
        s = env.reset()
        rall = 0
        while True:
            env.render()
            a = GLOBAL_AC.choose_action(s)
            s, r, d, _ = env.step(a)
            rall += r
            if d:
                print("reward", rall)
                break
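A note for readers of Worker.work(): the targets fed to the critic are n-step discounted returns bootstrapped from the critic's own estimate of the state after the last buffered step. The standalone sketch below (not part of the commit) replays that loop on a made-up three-step buffer so the backward fold is easy to check by hand; the reward values and the bootstrap value v_s_ are invented for illustration.

import numpy as np

GAMMA = 0.999                 # same discount factor as the script above
buffer_r = [1.0, -2.0, 0.5]   # made-up rewards for three buffered steps
v_s_ = 4.0                    # made-up critic estimate V(s') after the last step

# identical to the loop in Worker.work(): walk the rewards backwards,
# folding each one into the running bootstrapped return
buffer_v_target = []
for r in buffer_r[::-1]:
    v_s_ = r + GAMMA * v_s_
    buffer_v_target.append(v_s_)
buffer_v_target.reverse()

# buffer_v_target[0] == 1.0 + GAMMA * (-2.0 + GAMMA * (0.5 + GAMMA * 4.0))
print(np.vstack(buffer_v_target))   # the column vector fed to AC.v_target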
