@@ -83,6 +83,7 @@ def __init__(self,
                 fusion_flag = True
                 break
             idx += 1
+            # self.fusion_value_idx += 1
         if not fusion_flag:
             raise ValueError('Fusion value must be in actor output. '
                              'Please check your actor output.')
@@ -99,6 +100,8 @@ def __init__(self,
         # initialize actor
         # self.initialize_model_weights(self.actor)

+        self._sync_model_dict['actor'] = self.actor
+
         # create optimizer
         if self.level == 0:
             self.create_optimizer(name='actor', optimizer=self.actor.optimizer, lr=self.actor.learning_rate)
@@ -161,6 +164,42 @@ def train_actor(self,
         actor_grad = tape.gradient(loss, self.actor.trainable_variables)
         self.actor_optimizer.apply_gradients(zip(actor_grad, self.actor.trainable_variables))

+        # with tf.GradientTape() as tape:
+        #     out = self.resample_log_prob(actor_obs, action)
+        #     log_prob = out[0]
+        #     fusion_value = out[self.fusion_value_idx]
+        #
+        #     ratio = tf.exp(log_prob - old_log_prob)
+        #
+        #     actor_surrogate = tf.minimum(
+        #         ratio * advantage,
+        #         tf.clip_by_value(ratio, 1 - epsilon, 1 + epsilon) * advantage,
+        #     )
+        #
+        #     entropy = -log_prob
+        #     fusion_value_d = tf.square(fusion_value - target)
+        #
+        #     normalized_surrogate_loss = tf.reduce_mean(tf.math.l2_normalize(actor_surrogate, axis=0))
+        #
+        #     normalized_entropy_loss = tf.reduce_mean(tf.math.l2_normalize(entropy, axis=0))
+        #
+        #     normalized_fusion_value_loss = tf.reduce_mean(tf.math.l2_normalize(fusion_value_d, axis=0))
+        #
+        #     normalized_loss = -normalized_surrogate_loss + lam * normalized_fusion_value_loss - entropy_coefficient * normalized_entropy_loss
+        #
+        # normalized_actor_grad = tape.gradient(normalized_loss, self.actor.trainable_variables)
+        # self.actor_optimizer.apply_gradients(zip(normalized_actor_grad, self.actor.trainable_variables))
+
+        # dic = {
+        #     'actor_surrogate_loss': tf.reduce_mean(actor_surrogate),
+        #     'actor_loss': normalized_loss,
+        #     'fusion_value_loss': tf.reduce_mean(fusion_value_d),
+        #     'entropy_loss': tf.reduce_mean(entropy),
+        #     # 'normalized_actor_loss': normalized_loss,
+        #     'normalized_actor_surrogate_loss': normalized_surrogate_loss,
+        #     'normalized_fusion_value_loss': normalized_fusion_value_loss,
+        #     'normalized_entropy_loss': normalized_entropy_loss,
+        # }
         dic = {
             'actor_surrogate_loss': actor_surrogate_loss,
             'actor_loss': loss,
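Note on the hunk above: the commented-out block sketches an alternative actor objective in which each loss term is L2-normalized over the batch before the weighted sum, while the active code keeps the un-normalized loss. Below is a minimal, self-contained sketch of that normalized objective for reference; the helper name `normalized_fusion_actor_loss` and the standalone tensor arguments are illustrative assumptions, and in the agent `log_prob` and `fusion_value` come from `self.resample_log_prob(actor_obs, action)`.

import tensorflow as tf

def normalized_fusion_actor_loss(log_prob, old_log_prob, advantage,
                                 fusion_value, target,
                                 lam, epsilon, entropy_coefficient):
    # Hypothetical helper mirroring the commented-out block; not the committed code.
    # Clipped PPO surrogate on the probability ratio.
    ratio = tf.exp(log_prob - old_log_prob)
    surrogate = tf.minimum(
        ratio * advantage,
        tf.clip_by_value(ratio, 1 - epsilon, 1 + epsilon) * advantage,
    )

    # Entropy bonus (negative log-prob) and fusion-value regression term.
    entropy = -log_prob
    fusion_value_d = tf.square(fusion_value - target)

    # Each term is L2-normalized along the batch axis before averaging; this
    # normalization is what distinguishes the variant from the active loss.
    surrogate_loss = tf.reduce_mean(tf.math.l2_normalize(surrogate, axis=0))
    entropy_loss = tf.reduce_mean(tf.math.l2_normalize(entropy, axis=0))
    fusion_value_loss = tf.reduce_mean(tf.math.l2_normalize(fusion_value_d, axis=0))

    return -surrogate_loss + lam * fusion_value_loss - entropy_coefficient * entropy_loss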
@@ -252,72 +291,142 @@ def _optimize_(self):
         else:
             for idx in self.expand_dims_idx:
                 actor_obs[idx] = tf.expand_dims(actor_obs[idx], axis=1)
+        info_list = []
+        buffer_size = train_actor_input['actor_obs'][0].shape[0]
+
+        if self.hyper_parameters.batch_trajectory:
+            critic_batch_steps = self.hyper_parameters.batch_size * train_actor_input['actor_obs'][0].shape[1]
+        else:
+            critic_batch_steps = self.hyper_parameters.batch_size
+
+        critic_buffer_size = self.hyper_parameters.buffer_size

         for _ in range(self.hyper_parameters.update_times):
             # fusion ppo firstly update critic
-            for _ in range(self.hyper_parameters.update_critic_times):
-                start_index = 0
-                end_index = 0
+            start_index = 0
+            end_index = 0
+            critic_start_index = 0
+            while end_index < buffer_size:
+                end_index = min(start_index + self.hyper_parameters.batch_size,
+                                buffer_size)
+                critic_end_index = min(critic_start_index + critic_batch_steps, critic_buffer_size)
                 critic_optimize_info_list = []
+                actor_optimize_info_list = []
+                batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
+                batch_train_critic_input = self.get_batch_data(train_critic_input, critic_start_index, critic_end_index)
+                start_index = end_index
+                critic_start_index = critic_end_index
                 for _ in range(self.hyper_parameters.update_critic_times):
-                    while end_index < self.hyper_parameters.buffer_size:
-                        end_index = min(start_index + self.hyper_parameters.batch_size,
-                                        self.hyper_parameters.buffer_size)
+                    critic_optimize_info = self.train_critic(
+                        critic_obs=batch_train_critic_input['critic_obs'],
+                        target=batch_train_critic_input['target'],
+                    )
+                    critic_optimize_info_list.append(critic_optimize_info)

-                        batch_train_critic_input = self.get_batch_data(train_critic_input, start_index, end_index)
+                critic_value = self.critic(*batch_train_critic_input['critic_obs'])
+                critic_value_target = tf.math.reduce_mean(tf.square(critic_value - batch_train_critic_input['target']))

-                        start_index = end_index
+                out = self.resample_log_prob(batch_train_actor_input['actor_obs'], batch_train_actor_input['action'])
+                fusion_value = out[self.fusion_value_idx]

-                        critic_optimize_info = self.train_critic(
-                            critic_obs=batch_train_critic_input['critic_obs'],
-                            target=batch_train_critic_input['target'],
-                        )
-                        critic_optimize_info_list.append(critic_optimize_info)
-                critic_optimize_info = self.cal_average_batch_dict(critic_optimize_info_list)
+                # fusion_value = tf.reshape(fusion_value, critic_value.shape)
+                critic_value = tf.reshape(critic_value, shape=fusion_value.shape)

-            # fusion ppo secondly update actor
-            # compute lam
-            critic_value = self.critic(*critic_obs)
-            critic_value_target = tf.reduce_mean(tf.square(critic_value - target))
+                fusion_value_critic = tf.math.reduce_mean(tf.square(fusion_value - critic_value))

-            out = self.resample_log_prob(actor_obs, train_actor_input['action'])
-            fusion_value = out[self.fusion_value_idx]
-
-            fusion_value = tf.reshape(fusion_value, critic_value.shape)
-
-            fusion_value_critic = tf.reduce_mean(tf.square(fusion_value - critic_value))
+                # distance = tf.sqrt(critic_value_target) + tf.sqrt(fusion_value_critic)
+                distance = critic_value_target + fusion_value_critic

-            distance = tf.sqrt(critic_value_target) + tf.sqrt(fusion_value_critic)
+                lam = 1. / distance
+                lam = tf.clip_by_value(lam, 0, 0.2)
+                # lam = 1
+                lam = 0

-            batch_size = train_actor_input['actor_obs'][0].shape[0]
-
-            lam = 1. / distance
-            # lam = 0
-            for _ in range(self.hyper_parameters.update_actor_times):
-                start_index = 0
-                end_index = 0
-                actor_optimize_info_list = []
                 for _ in range(self.hyper_parameters.update_actor_times):
-                    while end_index < batch_size:
-                        end_index = min(start_index + self.hyper_parameters.batch_size,
-                                        batch_size)
-
-                        batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
-
-                        start_index = end_index
-
-                        actor_optimize_info = self.train_actor(
-                            actor_obs=batch_train_actor_input['actor_obs'],
-                            advantage=batch_train_actor_input['advantage'],
-                            old_log_prob=batch_train_actor_input['old_log_prob'],
-                            action=batch_train_actor_input['action'],
-                            target=batch_train_actor_input['target'],
-                            lam=lam,
-                            epsilon=tf.cast(self.hyper_parameters.epsilon, dtype=tf.float32),
-                            entropy_coefficient=tf.cast(self.hyper_parameters.entropy_coeff, dtype=tf.float32),
-                        )
-                        actor_optimize_info_list.append(actor_optimize_info)
-                actor_optimize_info = self.cal_average_batch_dict(actor_optimize_info_list)
-
-                return_dict = {**critic_optimize_info, **actor_optimize_info, 'lam': lam}
-                return return_dict
+                    actor_optimize_info = self.train_actor(
+                        actor_obs=batch_train_actor_input['actor_obs'],
+                        advantage=batch_train_actor_input['advantage'],
+                        old_log_prob=batch_train_actor_input['old_log_prob'],
+                        action=batch_train_actor_input['action'],
+                        target=batch_train_actor_input['target'],
+                        lam=lam,
+                        epsilon=tf.cast(self.hyper_parameters.epsilon, dtype=tf.float32),
+                        entropy_coefficient=tf.cast(self.hyper_parameters.entropy_coeff, dtype=tf.float32),
+                    )
+                    actor_optimize_info_list.append(actor_optimize_info)
+                critic_optimize_info = self.cal_average_batch_dict(critic_optimize_info_list)
+                actor_optimize_info = self.cal_average_batch_dict(actor_optimize_info_list)
+                info = {**critic_optimize_info, **actor_optimize_info, 'lam': lam}
+                info_list.append(info)
+
+        info = self.cal_average_batch_dict(info_list)
+
+        return info
+
+        # for _ in range(self.hyper_parameters.update_critic_times):
+        #     start_index = 0
+        #     end_index = 0
+        #     critic_optimize_info_list = []
+        #     for _ in range(self.hyper_parameters.update_critic_times):
+        #         while end_index < self.hyper_parameters.buffer_size:
+        #             end_index = min(start_index + self.hyper_parameters.batch_size,
+        #                             self.hyper_parameters.buffer_size)
+        #
+        #             batch_train_critic_input = self.get_batch_data(train_critic_input, start_index, end_index)
+        #
+        #             start_index = end_index
+        #
+        #             critic_optimize_info = self.train_critic(
+        #                 critic_obs=batch_train_critic_input['critic_obs'],
+        #                 target=batch_train_critic_input['target'],
+        #             )
+        #             critic_optimize_info_list.append(critic_optimize_info)
+        #
+        #
+        #
+        # # fusion ppo secondly update actor
+        # # compute lam
+        # critic_value = self.critic(*critic_obs)
+        # critic_value_target = tf.reduce_mean(tf.square(critic_value - target))
+        #
+        # out = self.resample_log_prob(actor_obs, train_actor_input['action'])
+        # fusion_value = out[self.fusion_value_idx]
+        #
+        # fusion_value = tf.reshape(fusion_value, critic_value.shape)
+        #
+        # fusion_value_critic = tf.reduce_mean(tf.square(fusion_value - critic_value))
+        #
+        # distance = tf.sqrt(critic_value_target) + tf.sqrt(fusion_value_critic)
+        #
+        # batch_size = train_actor_input['actor_obs'][0].shape[0]
+        #
+        # lam = 1. / distance
+        # # lam = 0
+        # for _ in range(self.hyper_parameters.update_actor_times):
+        #     start_index = 0
+        #     end_index = 0
+        #     actor_optimize_info_list = []
+        #     for _ in range(self.hyper_parameters.update_actor_times):
+        #         while end_index < batch_size:
+        #             end_index = min(start_index + self.hyper_parameters.batch_size,
+        #                             batch_size)
+        #
+        #             batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
+        #
+        #             start_index = end_index
+        #
+        #             actor_optimize_info = self.train_actor(
+        #                 actor_obs=batch_train_actor_input['actor_obs'],
+        #                 advantage=batch_train_actor_input['advantage'],
+        #                 old_log_prob=batch_train_actor_input['old_log_prob'],
+        #                 action=batch_train_actor_input['action'],
+        #                 target=batch_train_actor_input['target'],
+        #                 lam=lam,
+        #                 epsilon=tf.cast(self.hyper_parameters.epsilon, dtype=tf.float32),
+        #                 entropy_coefficient=tf.cast(self.hyper_parameters.entropy_coeff, dtype=tf.float32),
+        #             )
+        #             actor_optimize_info_list.append(actor_optimize_info)
+        #     actor_optimize_info = self.cal_average_batch_dict(actor_optimize_info_list)
+        #
+        #     return_dict = {**critic_optimize_info, **actor_optimize_info, 'lam': lam}
+        #     return return_dict
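Taken together, the rewritten `_optimize_` walks the buffer in mini-batches: each batch first gets `update_critic_times` critic steps, then a fusion weight `lam` is derived from how far the critic value sits from its regression target and from the actor's fusion value, and finally `update_actor_times` actor steps run before the per-batch metrics are averaged. Below is a minimal sketch of that `lam` computation, assuming placeholder tensors in place of the values the diff reads from `self.critic` and `self.resample_log_prob`; the helper name `compute_lam` is hypothetical.

import tensorflow as tf

def compute_lam(critic_value, target, fusion_value):
    # Hypothetical standalone version of the lam computation in the hunk above.
    # MSE between the critic prediction and its regression target.
    critic_value_target = tf.math.reduce_mean(tf.square(critic_value - target))

    # Match shapes, then MSE between the actor's fusion value and the critic value.
    critic_value = tf.reshape(critic_value, shape=fusion_value.shape)
    fusion_value_critic = tf.math.reduce_mean(tf.square(fusion_value - critic_value))

    # An earlier revision summed the square roots of the two terms; this commit
    # sums the raw MSEs and clips the inverse distance to [0, 0.2].
    distance = critic_value_target + fusion_value_critic
    return tf.clip_by_value(1. / distance, 0, 0.2)

As committed, the clipped value is immediately overridden by `lam = 0`, so the fusion-value term drops out of `train_actor` and the actor update reduces to the standard clipped-PPO objective.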