@@ -215,51 +215,95 @@ def _optimize_(self):
215215 for idx in self .expand_dims_idx :
216216 actor_obs [idx ] = tf .expand_dims (actor_obs [idx ], axis = 1 )
217217
218+ info_list = []
219+ buffer_size = train_actor_input ['actor_obs' ][0 ].shape [0 ]
220+ critic_buffer_size = self .hyper_parameters .buffer_size
221+ critic_batch_steps = self .hyper_parameters .batch_size
222+
218223 for _ in range (self .hyper_parameters .update_times ):
219- # train actor
220- # TODO: wrap this part into a function
221- for _ in range (self .hyper_parameters .update_actor_times ):
222- start_index = 0
223- end_index = 0
224+ # fusion ppo firstly update critic
225+ start_index = 0
226+ end_index = 0
227+ critic_start_index = 0
228+ while end_index < buffer_size :
229+ end_index = min (start_index + self .hyper_parameters .batch_size ,
230+ buffer_size )
231+ critic_end_index = min (critic_start_index + critic_batch_steps , critic_buffer_size )
232+ critic_optimize_info_list = []
224233 actor_optimize_info_list = []
225- while end_index < self .hyper_parameters .buffer_size :
226- end_index = min (start_index + self .hyper_parameters .batch_size , self .hyper_parameters .buffer_size )
227-
228- batch_train_actor_input = self .get_batch_data (train_actor_input , start_index , end_index )
229-
230- start_index = end_index
234+ batch_train_actor_input = self .get_batch_data (train_actor_input , start_index , end_index )
235+ batch_train_critic_input = self .get_batch_data (train_critic_input , critic_start_index , critic_end_index )
236+ start_index = end_index
237+ critic_start_index = critic_end_index
238+ for _ in range (self .hyper_parameters .update_critic_times ):
239+ critic_optimize_info = self .train_critic (
240+ critic_obs = batch_train_critic_input ['critic_obs' ],
241+ target = batch_train_critic_input ['target' ],
242+ )
243+ critic_optimize_info_list .append (critic_optimize_info )
231244
245+ for _ in range (self .hyper_parameters .update_actor_times ):
232246 actor_optimize_info = self .train_actor (
233247 actor_obs = batch_train_actor_input ['actor_obs' ],
234248 advantage = batch_train_actor_input ['advantage' ],
235249 old_log_prob = batch_train_actor_input ['old_log_prob' ],
236250 action = batch_train_actor_input ['action' ],
237- epsilon = tf .cast (self .hyper_parameters .epsilon , dtype = tf .float32 ),
238- entropy_coefficient = tf .cast (self .hyper_parameters .entropy_coeff , dtype = tf .float32 ),
239251 )
240252 actor_optimize_info_list .append (actor_optimize_info )
253+ critic_optimize_info = self .cal_average_batch_dict (critic_optimize_info_list )
241254 actor_optimize_info = self .cal_average_batch_dict (actor_optimize_info_list )
242-
243- # train critic
244- for _ in range (self .hyper_parameters .update_critic_times ):
245- start_index = 0
246- end_index = 0
247- critic_optimize_info_list = []
248- for _ in range (self .hyper_parameters .update_critic_times ):
249- while end_index < self .hyper_parameters .buffer_size :
250- end_index = min (start_index + self .hyper_parameters .batch_size ,
251- self .hyper_parameters .buffer_size )
252-
253- batch_train_critic_input = self .get_batch_data (train_critic_input , start_index , end_index )
254-
255- start_index = end_index
256-
257- critic_optimize_info = self .train_critic (
258- critic_obs = batch_train_critic_input ['critic_obs' ],
259- target = batch_train_critic_input ['target' ],
260- )
261- critic_optimize_info_list .append (critic_optimize_info )
262- critic_optimize_info = self .cal_average_batch_dict (critic_optimize_info_list )
263-
264- return_dict = {** actor_optimize_info , ** critic_optimize_info }
265- return return_dict
255+ info = {** critic_optimize_info , ** actor_optimize_info }
256+ info_list .append (info )
257+
258+ info = self .cal_average_batch_dict (info_list )
259+
260+ return info
261+
262+ # for _ in range(self.hyper_parameters.update_times):
263+ # # train actor
264+ # # TODO: wrap this part into a function
265+ # for _ in range(self.hyper_parameters.update_actor_times):
266+ # start_index = 0
267+ # end_index = 0
268+ # actor_optimize_info_list = []
269+ # while end_index < self.hyper_parameters.buffer_size:
270+ # end_index = min(start_index + self.hyper_parameters.batch_size, self.hyper_parameters.buffer_size)
271+ #
272+ # batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
273+ #
274+ # start_index = end_index
275+ #
276+ # actor_optimize_info = self.train_actor(
277+ # actor_obs=batch_train_actor_input['actor_obs'],
278+ # advantage=batch_train_actor_input['advantage'],
279+ # old_log_prob=batch_train_actor_input['old_log_prob'],
280+ # action=batch_train_actor_input['action'],
281+ # epsilon=tf.cast(self.hyper_parameters.epsilon, dtype=tf.float32),
282+ # entropy_coefficient=tf.cast(self.hyper_parameters.entropy_coeff, dtype=tf.float32),
283+ # )
284+ # actor_optimize_info_list.append(actor_optimize_info)
285+ # actor_optimize_info = self.cal_average_batch_dict(actor_optimize_info_list)
286+ #
287+ # # train critic
288+ # for _ in range(self.hyper_parameters.update_critic_times):
289+ # start_index = 0
290+ # end_index = 0
291+ # critic_optimize_info_list = []
292+ # for _ in range(self.hyper_parameters.update_critic_times):
293+ # while end_index < self.hyper_parameters.buffer_size:
294+ # end_index = min(start_index + self.hyper_parameters.batch_size,
295+ # self.hyper_parameters.buffer_size)
296+ #
297+ # batch_train_critic_input = self.get_batch_data(train_critic_input, start_index, end_index)
298+ #
299+ # start_index = end_index
300+ #
301+ # critic_optimize_info = self.train_critic(
302+ # critic_obs=batch_train_critic_input['critic_obs'],
303+ # target=batch_train_critic_input['target'],
304+ # )
305+ # critic_optimize_info_list.append(critic_optimize_info)
306+ # critic_optimize_info = self.cal_average_batch_dict(critic_optimize_info_list)
307+ #
308+ # return_dict = {**actor_optimize_info, **critic_optimize_info}
309+ # return return_dict
0 commit comments