Skip to content

Commit c285182

Browse files
committed
优化PPO (Optimize PPO)
优化PPO算法,使其更加稳定。 (Optimize the PPO algorithm to make it more stable.)
1 parent a70eb80 commit c285182

2 files changed

Lines changed: 84 additions & 40 deletions

File tree

AquaML/rlalgo/PPO.py

Lines changed: 81 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -215,51 +215,95 @@ def _optimize_(self):
215215
for idx in self.expand_dims_idx:
216216
actor_obs[idx] = tf.expand_dims(actor_obs[idx], axis=1)
217217

218+
info_list = []
219+
buffer_size = train_actor_input['actor_obs'][0].shape[0]
220+
critic_buffer_size = self.hyper_parameters.buffer_size
221+
critic_batch_steps = self.hyper_parameters.batch_size
222+
218223
for _ in range(self.hyper_parameters.update_times):
219-
# train actor
220-
# TODO: wrap this part into a function
221-
for _ in range(self.hyper_parameters.update_actor_times):
222-
start_index = 0
223-
end_index = 0
224+
# fusion ppo firstly update critic
225+
start_index = 0
226+
end_index = 0
227+
critic_start_index = 0
228+
while end_index < buffer_size:
229+
end_index = min(start_index + self.hyper_parameters.batch_size,
230+
buffer_size)
231+
critic_end_index = min(critic_start_index + critic_batch_steps, critic_buffer_size)
232+
critic_optimize_info_list = []
224233
actor_optimize_info_list = []
225-
while end_index < self.hyper_parameters.buffer_size:
226-
end_index = min(start_index + self.hyper_parameters.batch_size, self.hyper_parameters.buffer_size)
227-
228-
batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
229-
230-
start_index = end_index
234+
batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
235+
batch_train_critic_input = self.get_batch_data(train_critic_input, critic_start_index, critic_end_index)
236+
start_index = end_index
237+
critic_start_index = critic_end_index
238+
for _ in range(self.hyper_parameters.update_critic_times):
239+
critic_optimize_info = self.train_critic(
240+
critic_obs=batch_train_critic_input['critic_obs'],
241+
target=batch_train_critic_input['target'],
242+
)
243+
critic_optimize_info_list.append(critic_optimize_info)
231244

245+
for _ in range(self.hyper_parameters.update_actor_times):
232246
actor_optimize_info = self.train_actor(
233247
actor_obs=batch_train_actor_input['actor_obs'],
234248
advantage=batch_train_actor_input['advantage'],
235249
old_log_prob=batch_train_actor_input['old_log_prob'],
236250
action=batch_train_actor_input['action'],
237-
epsilon=tf.cast(self.hyper_parameters.epsilon, dtype=tf.float32),
238-
entropy_coefficient=tf.cast(self.hyper_parameters.entropy_coeff, dtype=tf.float32),
239251
)
240252
actor_optimize_info_list.append(actor_optimize_info)
253+
critic_optimize_info = self.cal_average_batch_dict(critic_optimize_info_list)
241254
actor_optimize_info = self.cal_average_batch_dict(actor_optimize_info_list)
242-
243-
# train critic
244-
for _ in range(self.hyper_parameters.update_critic_times):
245-
start_index = 0
246-
end_index = 0
247-
critic_optimize_info_list = []
248-
for _ in range(self.hyper_parameters.update_critic_times):
249-
while end_index < self.hyper_parameters.buffer_size:
250-
end_index = min(start_index + self.hyper_parameters.batch_size,
251-
self.hyper_parameters.buffer_size)
252-
253-
batch_train_critic_input = self.get_batch_data(train_critic_input, start_index, end_index)
254-
255-
start_index = end_index
256-
257-
critic_optimize_info = self.train_critic(
258-
critic_obs=batch_train_critic_input['critic_obs'],
259-
target=batch_train_critic_input['target'],
260-
)
261-
critic_optimize_info_list.append(critic_optimize_info)
262-
critic_optimize_info = self.cal_average_batch_dict(critic_optimize_info_list)
263-
264-
return_dict = {**actor_optimize_info, **critic_optimize_info}
265-
return return_dict
255+
info = {**critic_optimize_info, **actor_optimize_info}
256+
info_list.append(info)
257+
258+
info = self.cal_average_batch_dict(info_list)
259+
260+
return info
261+
262+
# for _ in range(self.hyper_parameters.update_times):
263+
# # train actor
264+
# # TODO: wrap this part into a function
265+
# for _ in range(self.hyper_parameters.update_actor_times):
266+
# start_index = 0
267+
# end_index = 0
268+
# actor_optimize_info_list = []
269+
# while end_index < self.hyper_parameters.buffer_size:
270+
# end_index = min(start_index + self.hyper_parameters.batch_size, self.hyper_parameters.buffer_size)
271+
#
272+
# batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
273+
#
274+
# start_index = end_index
275+
#
276+
# actor_optimize_info = self.train_actor(
277+
# actor_obs=batch_train_actor_input['actor_obs'],
278+
# advantage=batch_train_actor_input['advantage'],
279+
# old_log_prob=batch_train_actor_input['old_log_prob'],
280+
# action=batch_train_actor_input['action'],
281+
# epsilon=tf.cast(self.hyper_parameters.epsilon, dtype=tf.float32),
282+
# entropy_coefficient=tf.cast(self.hyper_parameters.entropy_coeff, dtype=tf.float32),
283+
# )
284+
# actor_optimize_info_list.append(actor_optimize_info)
285+
# actor_optimize_info = self.cal_average_batch_dict(actor_optimize_info_list)
286+
#
287+
# # train critic
288+
# for _ in range(self.hyper_parameters.update_critic_times):
289+
# start_index = 0
290+
# end_index = 0
291+
# critic_optimize_info_list = []
292+
# for _ in range(self.hyper_parameters.update_critic_times):
293+
# while end_index < self.hyper_parameters.buffer_size:
294+
# end_index = min(start_index + self.hyper_parameters.batch_size,
295+
# self.hyper_parameters.buffer_size)
296+
#
297+
# batch_train_critic_input = self.get_batch_data(train_critic_input, start_index, end_index)
298+
#
299+
# start_index = end_index
300+
#
301+
# critic_optimize_info = self.train_critic(
302+
# critic_obs=batch_train_critic_input['critic_obs'],
303+
# target=batch_train_critic_input['target'],
304+
# )
305+
# critic_optimize_info_list.append(critic_optimize_info)
306+
# critic_optimize_info = self.cal_average_batch_dict(critic_optimize_info_list)
307+
#
308+
# return_dict = {**actor_optimize_info, **critic_optimize_info}
309+
# return return_dict

Tutorial/Tutorial3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ def close(self):
120120
epoch_length=200,
121121
n_epochs=2000,
122122
total_steps=4000,
123-
batch_size=32,
123+
batch_size=128,
124124
update_times=4,
125-
update_actor_times=1,
126-
update_critic_times=2,
125+
update_actor_times=4,
126+
update_critic_times=4,
127127
gamma=0.99,
128128
epsilon=0.2,
129129
lambada=0.95

0 commit comments

Comments (0)