- ###
- # Compared with ppo_gae_continuous2, using minibatch SGD, separate actor and critic networks.
- ###
+ '''
+ Probabilistic Mixture-of-Experts (PMOE)
+ paper: https://arxiv.org/abs/2104.09122
+
+ Core features:
+ It replaces the diagonal Gaussian distribution with a (differentiable) Gaussian mixture model for policy function approximation, which is more expressive.
+ This version is built on the on-policy PPO algorithm.
+ '''

import gym
import torch
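For orientation, here is a minimal, self-contained sketch of the GMM policy head the new docstring describes: K diagonal-Gaussian experts plus learned mixing coefficients. The class and layer names (GMMPolicy, body, coef) and the state-independent log-std are illustrative assumptions, not this repository's actual modules.

import torch
import torch.nn as nn
from torch.distributions import Normal, Categorical

class GMMPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64, mix_num=5):
        super().__init__()
        self.mix_num, self.action_dim = mix_num, action_dim
        self.body = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.Tanh())
        self.mean = nn.Linear(hidden_dim, mix_num * action_dim)        # K component means
        self.log_std = nn.Parameter(torch.zeros(mix_num, action_dim))  # state-independent log-stds
        self.coef = nn.Linear(hidden_dim, mix_num)                     # mixing logits

    def forward(self, x):
        h = self.body(x)
        mean = self.mean(h).view(-1, self.mix_num, self.action_dim)
        mix_coef = torch.softmax(self.coef(h), dim=-1)  # sums to 1 over the K experts
        return mean, self.log_std, mix_coef

# Sampling: pick an expert via the categorical mixing distribution,
# then draw the action from that expert's diagonal Gaussian.
policy = GMMPolicy(state_dim=3, action_dim=1)
mean, log_std, mix_coef = policy(torch.randn(1, 3))
k = Categorical(mix_coef).sample()                                  # (1,) chosen expert
action = Normal(mean, log_std.exp()).sample()[torch.arange(1), k]   # (1, action_dim)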
@@ -134,18 +139,28 @@ def get_action_and_value(self, x, action=None, select_from_mixture=True, track_grad
        mean, log_std, mix_coef = self.pi(x)
        std = torch.exp(log_std)
        normal = Normal(mean, std)
+
+       full_action = normal.sample()
+       if select_from_mixture:
+           mix_dist = Categorical(mix_coef)
+           index = mix_dist.sample()
+           a = full_action[index]
+       else:
+           a = full_action
+
        if action is None:
-           full_action = normal.sample()
-           if select_from_mixture:
-               index = Categorical(mix_coef).sample()
-               action = full_action[index]
-           # log_prob = normal.log_prob(full_action).sum((-1, -2))  # TODO prob of other entries?
-           log_prob = normal.log_prob(action).sum(-1)  # TODO prob of other entries?
+           a_for_prob = a.unsqueeze(-2)  # to (1, action_dim), broadcasting against mean and std of shape (K, action_dim)
+           log_prob = (mix_coef @ normal.log_prob(a_for_prob).sum(-1).exp()).log()
+       else:  # works for a batch of given actions
+           a_for_prob = action.unsqueeze(-2)  # use the given action when computing the probability
+           log_prob = torch.einsum('ij,ij->i', mix_coef, normal.log_prob(a_for_prob).sum(-1).exp()).log()  # probability of the action under the whole GMM, including the mixing coefficients
+
        value = self.v(x)
        if track_grad:
-           return action, log_prob, value, mix_coef
+           return a, log_prob, value, mix_coef
        else:
-           return action.cpu().numpy(), log_prob, value, mix_coef
+           return a.cpu().numpy(), log_prob, value, mix_coef

    def put_data(self, transition):
        self.data.append(transition)
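The new log-prob lines mix the per-expert densities in probability space and then take the log: log pi(a|s) = log sum_k w_k exp(sum_i log N(a_i; mu_ki, sigma_ki)). Below is a small self-contained check of that einsum form, with shapes assumed from the diff (mix_coef and the summed per-expert log-probs are both (B, K)), alongside an equivalent torch.logsumexp variant:

import torch

B, K = 4, 5
mix_coef = torch.softmax(torch.randn(B, K), dim=-1)  # mixing weights, rows sum to 1
comp_logprob = torch.randn(B, K)                     # Normal(...).log_prob(a).sum(-1) per expert

# As in the diff: exponentiate, mix, then take the log.
log_prob_diff = torch.einsum('ij,ij->i', mix_coef, comp_logprob.exp()).log()

# Numerically stabler equivalent: log sum_k exp(log w_k + log p_k).
log_prob_lse = torch.logsumexp(mix_coef.log() + comp_logprob, dim=-1)

assert torch.allclose(log_prob_diff, log_prob_lse, atol=1e-5)

The exp/log round trip in the diff can underflow to -inf for actions far from every component; the logsumexp form sidesteps that.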
@@ -159,6 +174,7 @@ def make_batch(self,):

    def train_net(self):
        s, a, r, s_prime, done_mask, logprob_a, v = self.make_batch()
+       loss_list = []
        with torch.no_grad():
            advantage = torch.zeros_like(r).to(device)
            lastgaelam = 0
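The advantage buffer and lastgaelam initialize the standard backward GAE recursion, A_t = delta_t + gamma * lambda * A_{t+1} with delta_t = r_t + gamma * V(s_{t+1}) - V(s_t). A sketch of that loop under assumed conventions (done_mask is 0 at episode ends; the gamma/lmbda names follow common PPO code, not necessarily this file's exact variables):

import torch

def compute_gae(r, v, v_prime, done_mask, gamma=0.99, lmbda=0.95):
    # r, v, v_prime, done_mask: 1-D tensors over one rollout of length T.
    advantage = torch.zeros_like(r)
    lastgaelam = 0.0
    for t in reversed(range(r.shape[0])):
        delta = r[t] + gamma * v_prime[t] * done_mask[t] - v[t]         # TD error
        lastgaelam = delta + gamma * lmbda * done_mask[t] * lastgaelam  # GAE recursion
        advantage[t] = lastgaelam
    return advantage, advantage + v  # advantages and TD(lambda) value targets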
@@ -189,20 +205,19 @@ def train_net(self):
            # get mixing coefficients loss
            new_a, newlogprob_a, new_vs, new_mix_coef = self.get_action_and_value(bs, ba, select_from_mixture=False, track_grad=True)
            new_q = self.q_net(bs, new_a, match_shape=True)
-           _, best_index = new_q.max(1)
-           print(new_q.shape, best_index.shape, new_mix_coef.shape)
-           coef_loss = F.mse_loss(new_mix_coef, F.one_hot(best_index, self.mix_num).float())
+           _, best_index = new_q.max(-1)
+           coef_loss = F.mse_loss(new_mix_coef, F.one_hot(best_index, self.mix_num).float()).mean()

            # Q-net loss
-           pred_q = self.q_net(bs, ba)
-           q_loss = F.mse_loss(pred_q, btd_target)
-
+           pred_q = self.q_net(bs, ba).squeeze()
+           q_loss = F.mse_loss(pred_q, btd_target).mean()

            new_vs = new_vs.reshape(-1)
            ratio = torch.exp(newlogprob_a - blogprob_a)  # a/b == exp(log(a) - log(b))
            surr1 = -ratio * badvantage
            surr2 = -torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * badvantage
            policy_loss = torch.max(surr1, surr2).mean()

            if self.v_loss_clip:  # clipped value loss
                v_clipped = bv + torch.clamp(new_vs - bv, -eps_clip, eps_clip)
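Two pieces of this hunk are worth a gloss. The coefficient loss regresses the mixing weights toward a one-hot vector on the expert whose proposed action the learned Q-network scores highest, which distills the non-differentiable expert selection into the differentiable coefficients. surr1/surr2 form the usual PPO clipped surrogate, -min(r_t * A, clip(r_t, 1 - eps, 1 + eps) * A), written here as a max over negated terms. A toy sketch of the coefficient target, assuming new_q has shape (B, K) with one Q-value per expert's action:

import torch
import torch.nn.functional as F

B, K = 4, 5                                   # batch size, number of experts
new_q = torch.randn(B, K)                     # Q(s, a_k) for each expert's proposed action
new_mix_coef = torch.softmax(torch.randn(B, K), dim=-1)

_, best_index = new_q.max(-1)                 # index of the best-scoring expert, (B,)
target = F.one_hot(best_index, K).float()     # one-hot target, (B, K)
coef_loss = F.mse_loss(new_mix_coef, target)  # pulls the mixing weights toward it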
@@ -218,6 +233,8 @@ def train_net(self):
            loss.backward()
            nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm)
            self.optimizer.step()
+       loss_list = [coef_loss.item(), q_loss.item(), policy_loss.item(), value_loss.item()]
+       return loss_list

def main():
    args = parse_args()
@@ -234,18 +251,21 @@ def main():
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
+   print(env.observation_space, env.action_space)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    hidden_dim = 64
-   model = PMOE_PPO(state_dim, action_dim, hidden_dim)
+   mix_num = 5  # number of experts in the mixture
+   model = PMOE_PPO(state_dim, action_dim, hidden_dim, mix_num)
    score = 0.0
    print_interval = 1
    step = 0
    update = 1
+   loss_list = []

    if args.wandb_activate:
        wandb.init(
@@ -277,12 +297,13 @@ def main():
            # env.render()

            model.put_data((s, a, r, s_prime, logprob, v.squeeze(-1), done))
+
            s = s_prime

            score += r

            if step % batch_size == 0 and step > 0:
-               model.train_net()
+               loss_list = model.train_net()
                update += 1
                eff_update = update
@@ -298,6 +319,12 @@ def main():
                writer.add_scalar("charts/episodic_return", epi_r, n_epi)
                writer.add_scalar("charts/episodic_length", t, n_epi)
                writer.add_scalar("charts/update", update, n_epi)
+               if len(loss_list) > 0:
+                   writer.add_scalar("charts/coeff_loss", loss_list[0], n_epi)
+                   writer.add_scalar("charts/Q_loss", loss_list[1], n_epi)
+                   writer.add_scalar("charts/policy_loss", loss_list[2], n_epi)
+                   writer.add_scalar("charts/value_loss", loss_list[3], n_epi)

            score = 0.0